Skip to main content

matrixcode_core/agent/
streaming.rs

1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13    while !token.is_cancelled() {
14        sleep(Duration::from_millis(100)).await;
15    }
16}
17
18impl Agent {
19    /// Drain any pending input messages from the channel.
20    /// Called during streaming to collect real-time appended messages.
21    pub(crate) fn drain_pending_inputs(&mut self) {
22        if let Some(rx) = &mut self.pending_input_rx {
23            while let Ok(msg) = rx.try_recv() {
24                log::info!(
25                    "Agent received pending input: {}",
26                    msg.chars().take(50).collect::<String>()
27                );
28                self.pending_inputs.push(msg);
29            }
30        }
31    }
32
33    /// Check if there are pending inputs waiting to be processed.
34    pub fn has_pending_inputs(&self) -> bool {
35        !self.pending_inputs.is_empty()
36    }
37
38    /// Get and clear all pending inputs.
39    pub fn take_pending_inputs(&mut self) -> Vec<String> {
40        let inputs = self.pending_inputs.clone();
41        self.pending_inputs.clear();
42        inputs
43    }
44
45    /// Call provider with streaming and emit events in real-time.
46    /// Also monitors pending_input_rx for real-time message appending.
47    pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
48        const MAX_RETRIES: u32 = 5;
49        const RETRY_DELAY_MS: u64 = 1000;
50
51        let mut attempt = 0;
52
53        loop {
54            attempt += 1;
55            log::info!(
56                "Agent: API call attempt {} with {} messages",
57                attempt,
58                request.messages.len()
59            );
60
61            if let Some(token) = &self.cancel_token
62                && token.is_cancelled()
63            {
64                return Err(anyhow::anyhow!("Operation cancelled"));
65            }
66
67            log::info!("Agent: calling provider.chat_stream");
68            let rx_result = self.provider.chat_stream(request.clone()).await;
69            log::info!("Agent: provider.chat_stream returned");
70
71            match rx_result {
72                Ok(mut rx) => {
73                    let mut response_content: Vec<ContentBlock> = Vec::new();
74                    let mut current_text = String::new();
75                    let mut current_thinking = String::new();
76                    let mut usage = Usage {
77                        input_tokens: 0,
78                        output_tokens: 0,
79                        cache_creation_input_tokens: 0,
80                        cache_read_input_tokens: 0,
81                    };
82                    let mut should_retry = false;
83
84                    loop {
85                        // Use select! with cancellation and pending input checks
86                        let event = if let Some(token) = &self.cancel_token {
87                            tokio::select! {
88                                // Primary: receive stream event
89                                event = rx.recv() => event,
90                                // Check for pending inputs periodically
91                                _ = sleep(Duration::from_millis(50)) => {
92                                    self.drain_pending_inputs();
93                                    continue;
94                                }
95                                // Cancellation signal
96                                _ = wait_for_cancel_stream(token) => {
97                                    return Err(anyhow::anyhow!("Operation cancelled"));
98                                }
99                            }
100                        } else {
101                            // No cancellation token, but still check pending inputs
102                            tokio::select! {
103                                event = rx.recv() => event,
104                                _ = sleep(Duration::from_millis(50)) => {
105                                    self.drain_pending_inputs();
106                                    continue;
107                                }
108                            }
109                        };
110
111                        match event {
112                            None => break,
113                            Some(StreamEvent::FirstByte) => {}
114                            Some(StreamEvent::ThinkingDelta(delta)) => {
115                                // Check cancellation before emitting
116                                if let Some(token) = &self.cancel_token
117                                    && token.is_cancelled()
118                                {
119                                    return Err(anyhow::anyhow!("Operation cancelled"));
120                                }
121                                if current_thinking.is_empty() {
122                                    self.emit(AgentEvent::thinking_start())?;
123                                }
124                                current_thinking.push_str(&delta);
125                                self.emit(AgentEvent::thinking_delta(delta, None))?;
126                            }
127                            Some(StreamEvent::TextDelta(delta)) => {
128                                // Check cancellation before emitting
129                                if let Some(token) = &self.cancel_token
130                                    && token.is_cancelled()
131                                {
132                                    return Err(anyhow::anyhow!("Operation cancelled"));
133                                }
134                                if current_text.is_empty() {
135                                    self.emit(AgentEvent::text_start())?;
136                                }
137                                current_text.push_str(&delta);
138                                self.emit(AgentEvent::text_delta(delta))?;
139                            }
140                            Some(StreamEvent::ToolUseStart { id, name }) => {
141                                // Emit events for UI but don't push content blocks
142                                // Content will be added from Done event's resp.content
143                                if !current_thinking.is_empty() {
144                                    self.emit(AgentEvent::thinking_end())?;
145                                    current_thinking.clear();
146                                }
147                                if !current_text.is_empty() {
148                                    self.emit(AgentEvent::text_end())?;
149                                    current_text.clear();
150                                }
151                                self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
152                            }
153                            Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
154                            Some(StreamEvent::ToolInputComplete { id, name, input }) => {
155                                self.previewed_tool_inputs.insert(id.clone());
156                                self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
157                            }
158                            Some(StreamEvent::Usage { output_tokens }) => {
159                                self.emit(AgentEvent::usage_with_cache(
160                                    0,
161                                    output_tokens as u64,
162                                    0,
163                                    0,
164                                ))?;
165                                usage.output_tokens = output_tokens;
166                            }
167                            Some(StreamEvent::Done(resp)) => {
168                                // Check cancellation before processing final response
169                                if let Some(token) = &self.cancel_token
170                                    && token.is_cancelled()
171                                {
172                                    return Err(anyhow::anyhow!("Operation cancelled"));
173                                }
174
175                                // Final drain of pending inputs before completing
176                                self.drain_pending_inputs();
177
178                                // IMPORTANT: Add current_thinking/current_text to response_content FIRST
179                                // before checking for duplicates from resp.content
180                                // This ensures all streamed content is preserved
181                                if !current_thinking.is_empty() {
182                                    self.emit(AgentEvent::thinking_end())?;
183                                    // Add to response_content with signature from resp if available
184                                    let signature = resp.content.iter()
185                                        .find_map(|b| {
186                                            if let ContentBlock::Thinking { thinking, signature } = b {
187                                                if thinking == &current_thinking {
188                                                    signature.clone()
189                                                } else {
190                                                    None
191                                                }
192                                            } else {
193                                                None
194                                            }
195                                        });
196                                    response_content.push(ContentBlock::Thinking {
197                                        thinking: current_thinking.clone(),
198                                        signature,
199                                    });
200                                    current_thinking.clear();
201                                }
202                                if !current_text.is_empty() {
203                                    self.emit(AgentEvent::text_end())?;
204                                    // Add to response_content
205                                    response_content.push(ContentBlock::Text {
206                                        text: current_text.clone(),
207                                    });
208                                    current_text.clear();
209                                }
210
211                                // Then add any additional blocks from final response that are NOT duplicates
212                                for block in &resp.content {
213                                    // Smart deduplication: compare content, not entire block
214                                    let is_duplicate = response_content.iter().any(|b| {
215                                        match (b, block) {
216                                            // For Thinking blocks, compare thinking content only (signature may differ)
217                                            (
218                                                ContentBlock::Thinking { thinking: t1, .. },
219                                                ContentBlock::Thinking { thinking: t2, .. },
220                                            ) => t1 == t2,
221                                            // For Text blocks, compare text content
222                                            (
223                                                ContentBlock::Text { text: t1 },
224                                                ContentBlock::Text { text: t2 },
225                                            ) => t1 == t2,
226                                            // For ToolUse, compare id
227                                            (
228                                                ContentBlock::ToolUse { id: id1, .. },
229                                                ContentBlock::ToolUse { id: id2, .. },
230                                            ) => id1 == id2,
231                                            // For ToolResult, compare tool_use_id
232                                            (
233                                                ContentBlock::ToolResult {
234                                                    tool_use_id: id1, ..
235                                                },
236                                                ContentBlock::ToolResult {
237                                                    tool_use_id: id2, ..
238                                                },
239                                            ) => id1 == id2,
240                                            // Default: exact comparison
241                                            _ => b == block,
242                                        }
243                                    });
244                                    if !is_duplicate {
245                                        response_content.push(block.clone());
246                                    }
247                                }
248                                usage = resp.usage;
249                            }
250                            Some(StreamEvent::Error(msg)) => {
251                                if attempt < MAX_RETRIES {
252                                    self.emit(AgentEvent::progress(
253                                        format!(
254                                            "⚠️ Stream error, retrying ({}/{}): {}",
255                                            attempt, MAX_RETRIES, &msg
256                                        ),
257                                        None,
258                                    ))?;
259                                    let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
260                                    tokio::time::sleep(tokio::time::Duration::from_millis(delay))
261                                        .await;
262                                    should_retry = true;
263                                    break;
264                                } else {
265                                    self.emit(AgentEvent::error(msg.clone(), None, None))?;
266                                    return Err(anyhow::anyhow!(
267                                        "Stream error after {} retries: {}",
268                                        MAX_RETRIES,
269                                        msg
270                                    ));
271                                }
272                            }
273                        }
274                    }
275
276                    if should_retry {
277                        continue;
278                    }
279
280                    return Ok(ChatResponse {
281                        content: response_content,
282                        stop_reason: StopReason::EndTurn,
283                        usage,
284                    });
285                }
286                Err(e) => {
287                    if attempt < MAX_RETRIES {
288                        let error_msg = e.to_string();
289                        self.emit(AgentEvent::progress(
290                            format!(
291                                "⚠️ API error, retrying ({}/{}): {}",
292                                attempt, MAX_RETRIES, &error_msg
293                            ),
294                            None,
295                        ))?;
296                        let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
297                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
298                    } else {
299                        return Err(anyhow::anyhow!(
300                            "API error after {} retries: {}",
301                            MAX_RETRIES,
302                            e
303                        ));
304                    }
305                }
306            }
307        }
308    }
309}