Skip to main content

matrixcode_core/agent/
streaming.rs

1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13    while !token.is_cancelled() {
14        sleep(Duration::from_millis(100)).await;
15    }
16}
17
18impl Agent {
19    /// Drain any pending input messages from the channel.
20    /// Called during streaming to collect real-time appended messages.
21    pub(crate) fn drain_pending_inputs(&mut self) {
22        if let Some(rx) = &mut self.pending_input_rx {
23            while let Ok(msg) = rx.try_recv() {
24                log::info!(
25                    "Agent received pending input: {}",
26                    msg.chars().take(50).collect::<String>()
27                );
28                self.pending_inputs.push(msg);
29            }
30        }
31    }
32
33    /// Check if there are pending inputs waiting to be processed.
34    pub fn has_pending_inputs(&self) -> bool {
35        !self.pending_inputs.is_empty()
36    }
37
38    /// Get and clear all pending inputs.
39    pub fn take_pending_inputs(&mut self) -> Vec<String> {
40        let inputs = self.pending_inputs.clone();
41        self.pending_inputs.clear();
42        inputs
43    }
44
45    /// Call provider with streaming and emit events in real-time.
46    /// Also monitors pending_input_rx for real-time message appending.
47    pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
48        const MAX_RETRIES: u32 = 5;
49        const RETRY_DELAY_MS: u64 = 1000;
50
51        let mut attempt = 0;
52
53        loop {
54            attempt += 1;
55            log::info!(
56                "Agent: API call attempt {} with {} messages",
57                attempt,
58                request.messages.len()
59            );
60
61            if let Some(token) = &self.cancel_token
62                && token.is_cancelled()
63            {
64                return Err(anyhow::anyhow!("Operation cancelled"));
65            }
66
67            log::info!("Agent: calling provider.chat_stream");
68            let rx_result = self.provider.chat_stream(request.clone()).await;
69            log::info!("Agent: provider.chat_stream returned");
70
71            match rx_result {
72                Ok(mut rx) => {
73                    let mut response_content: Vec<ContentBlock> = Vec::new();
74                    let mut current_text = String::new();
75                    let mut current_thinking = String::new();
76                    let mut usage = Usage {
77                        input_tokens: 0,
78                        output_tokens: 0,
79                        cache_creation_input_tokens: 0,
80                        cache_read_input_tokens: 0,
81                    };
82                    let mut should_retry = false;
83
84                    loop {
85                        // Use select! with cancellation and pending input checks
86                        let event = if let Some(token) = &self.cancel_token {
87                            tokio::select! {
88                                // Primary: receive stream event
89                                event = rx.recv() => event,
90                                // Check for pending inputs periodically
91                                _ = sleep(Duration::from_millis(50)) => {
92                                    self.drain_pending_inputs();
93                                    continue;
94                                }
95                                // Cancellation signal
96                                _ = wait_for_cancel_stream(token) => {
97                                    return Err(anyhow::anyhow!("Operation cancelled"));
98                                }
99                            }
100                        } else {
101                            // No cancellation token, but still check pending inputs
102                            tokio::select! {
103                                event = rx.recv() => event,
104                                _ = sleep(Duration::from_millis(50)) => {
105                                    self.drain_pending_inputs();
106                                    continue;
107                                }
108                            }
109                        };
110
111                        match event {
112                            None => break,
113                            Some(StreamEvent::FirstByte) => {}
114                            Some(StreamEvent::ThinkingDelta(delta)) => {
115                                // Check cancellation before emitting
116                                if let Some(token) = &self.cancel_token
117                                    && token.is_cancelled()
118                                {
119                                    return Err(anyhow::anyhow!("Operation cancelled"));
120                                }
121                                if current_thinking.is_empty() {
122                                    self.emit(AgentEvent::thinking_start())?;
123                                }
124                                current_thinking.push_str(&delta);
125                                self.emit(AgentEvent::thinking_delta(delta, None))?;
126                            }
127                            Some(StreamEvent::TextDelta(delta)) => {
128                                // Check cancellation before emitting
129                                if let Some(token) = &self.cancel_token
130                                    && token.is_cancelled()
131                                {
132                                    return Err(anyhow::anyhow!("Operation cancelled"));
133                                }
134                                if current_text.is_empty() {
135                                    self.emit(AgentEvent::text_start())?;
136                                }
137                                current_text.push_str(&delta);
138                                self.emit(AgentEvent::text_delta(delta))?;
139                            }
140                            Some(StreamEvent::ToolUseStart { id, name }) => {
141                                // Emit events for UI but don't push content blocks
142                                // Content will be added from Done event's resp.content
143                                if !current_thinking.is_empty() {
144                                    self.emit(AgentEvent::thinking_end())?;
145                                    // Don't push - will be added from resp.content
146                                }
147                                if !current_text.is_empty() {
148                                    self.emit(AgentEvent::text_end())?;
149                                    // Don't push - will be added from resp.content
150                                }
151                                self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
152                            }
153                            Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
154                            Some(StreamEvent::ToolInputComplete { id, name, input }) => {
155                                self.previewed_tool_inputs.insert(id.clone());
156                                self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
157                            }
158                            Some(StreamEvent::Usage { output_tokens }) => {
159                                self.emit(AgentEvent::usage_with_cache(
160                                    0,
161                                    output_tokens as u64,
162                                    0,
163                                    0,
164                                ))?;
165                                usage.output_tokens = output_tokens;
166                            }
167                            Some(StreamEvent::Done(resp)) => {
168                                // Check cancellation before processing final response
169                                if let Some(token) = &self.cancel_token
170                                    && token.is_cancelled()
171                                {
172                                    return Err(anyhow::anyhow!("Operation cancelled"));
173                                }
174
175                                // Final drain of pending inputs before completing
176                                self.drain_pending_inputs();
177
178                                // Don't add current_thinking/current_text here - use resp.content directly
179                                // This avoids duplicates since resp.content contains everything
180                                // Just emit events for UI updates if we have pending content
181                                if !current_thinking.is_empty() {
182                                    self.emit(AgentEvent::thinking_end())?;
183                                    // Don't push to response_content - will be added from resp.content
184                                }
185                                if !current_text.is_empty() {
186                                    self.emit(AgentEvent::text_end())?;
187                                    // Don't push to response_content - will be added from resp.content
188                                }
189
190                                // Add all blocks from final response with smart deduplication
191                                for block in &resp.content {
192                                    // Smart deduplication: compare content, not entire block
193                                    let is_duplicate = response_content.iter().any(|b| {
194                                        match (b, block) {
195                                            // For Thinking blocks, compare thinking content only (signature may differ)
196                                            (
197                                                ContentBlock::Thinking { thinking: t1, .. },
198                                                ContentBlock::Thinking { thinking: t2, .. },
199                                            ) => t1 == t2,
200                                            // For Text blocks, compare text content
201                                            (
202                                                ContentBlock::Text { text: t1 },
203                                                ContentBlock::Text { text: t2 },
204                                            ) => t1 == t2,
205                                            // For ToolUse, compare id
206                                            (
207                                                ContentBlock::ToolUse { id: id1, .. },
208                                                ContentBlock::ToolUse { id: id2, .. },
209                                            ) => id1 == id2,
210                                            // For ToolResult, compare tool_use_id
211                                            (
212                                                ContentBlock::ToolResult {
213                                                    tool_use_id: id1, ..
214                                                },
215                                                ContentBlock::ToolResult {
216                                                    tool_use_id: id2, ..
217                                                },
218                                            ) => id1 == id2,
219                                            // Default: exact comparison
220                                            _ => b == block,
221                                        }
222                                    });
223                                    if !is_duplicate {
224                                        response_content.push(block.clone());
225                                    }
226                                }
227                                usage = resp.usage;
228                            }
229                            Some(StreamEvent::Error(msg)) => {
230                                if attempt < MAX_RETRIES {
231                                    self.emit(AgentEvent::progress(
232                                        format!(
233                                            "⚠️ Stream error, retrying ({}/{}): {}",
234                                            attempt, MAX_RETRIES, &msg
235                                        ),
236                                        None,
237                                    ))?;
238                                    let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
239                                    tokio::time::sleep(tokio::time::Duration::from_millis(delay))
240                                        .await;
241                                    should_retry = true;
242                                    break;
243                                } else {
244                                    self.emit(AgentEvent::error(msg.clone(), None, None))?;
245                                    return Err(anyhow::anyhow!(
246                                        "Stream error after {} retries: {}",
247                                        MAX_RETRIES,
248                                        msg
249                                    ));
250                                }
251                            }
252                        }
253                    }
254
255                    if should_retry {
256                        continue;
257                    }
258
259                    return Ok(ChatResponse {
260                        content: response_content,
261                        stop_reason: StopReason::EndTurn,
262                        usage,
263                    });
264                }
265                Err(e) => {
266                    if attempt < MAX_RETRIES {
267                        let error_msg = e.to_string();
268                        self.emit(AgentEvent::progress(
269                            format!(
270                                "⚠️ API error, retrying ({}/{}): {}",
271                                attempt, MAX_RETRIES, &error_msg
272                            ),
273                            None,
274                        ))?;
275                        let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
276                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
277                    } else {
278                        return Err(anyhow::anyhow!(
279                            "API error after {} retries: {}",
280                            MAX_RETRIES,
281                            e
282                        ));
283                    }
284                }
285            }
286        }
287    }
288}