Skip to main content

matrixcode_core/agent/
streaming.rs

1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13    while !token.is_cancelled() {
14        sleep(Duration::from_millis(100)).await;
15    }
16}
17
18impl Agent {
19    /// Drain any pending input messages from the channel.
20    /// Called during streaming to collect real-time appended messages.
21    fn drain_pending_inputs(&mut self) {
22        if let Some(rx) = &mut self.pending_input_rx {
23            while let Ok(msg) = rx.try_recv() {
24                log::info!("Agent received pending input: {}", msg.chars().take(50).collect::<String>());
25                self.pending_inputs.push(msg);
26            }
27        }
28    }
29
30    /// Check if there are pending inputs waiting to be processed.
31    pub fn has_pending_inputs(&self) -> bool {
32        !self.pending_inputs.is_empty()
33    }
34
35    /// Get and clear all pending inputs.
36    pub fn take_pending_inputs(&mut self) -> Vec<String> {
37        let inputs = self.pending_inputs.clone();
38        self.pending_inputs.clear();
39        inputs
40    }
41
42    /// Call provider with streaming and emit events in real-time.
43    /// Also monitors pending_input_rx for real-time message appending.
44    pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
45        const MAX_RETRIES: u32 = 5;
46        const RETRY_DELAY_MS: u64 = 1000;
47
48        let mut attempt = 0;
49
50        loop {
51            attempt += 1;
52            log::info!("Agent: API call attempt {} with {} messages", attempt, request.messages.len());
53
54            if let Some(token) = &self.cancel_token
55                && token.is_cancelled()
56            {
57                return Err(anyhow::anyhow!("Operation cancelled"));
58            }
59
60            log::info!("Agent: calling provider.chat_stream");
61            let rx_result = self.provider.chat_stream(request.clone()).await;
62            log::info!("Agent: provider.chat_stream returned");
63
64            match rx_result {
65                Ok(mut rx) => {
66                    let mut response_content: Vec<ContentBlock> = Vec::new();
67                    let mut current_text = String::new();
68                    let mut current_thinking = String::new();
69                    let mut usage = Usage {
70                        input_tokens: 0,
71                        output_tokens: 0,
72                        cache_creation_input_tokens: 0,
73                        cache_read_input_tokens: 0,
74                    };
75                    let mut should_retry = false;
76
77                    loop {
78                        // Use select! with cancellation and pending input checks
79                        let event = if let Some(token) = &self.cancel_token {
80                            tokio::select! {
81                                // Primary: receive stream event
82                                event = rx.recv() => event,
83                                // Check for pending inputs periodically
84                                _ = sleep(Duration::from_millis(50)) => {
85                                    self.drain_pending_inputs();
86                                    continue;
87                                }
88                                // Cancellation signal
89                                _ = wait_for_cancel_stream(token) => {
90                                    return Err(anyhow::anyhow!("Operation cancelled"));
91                                }
92                            }
93                        } else {
94                            // No cancellation token, but still check pending inputs
95                            tokio::select! {
96                                event = rx.recv() => event,
97                                _ = sleep(Duration::from_millis(50)) => {
98                                    self.drain_pending_inputs();
99                                    continue;
100                                }
101                            }
102                        };
103
104                        match event {
105                            None => break,
106                            Some(StreamEvent::FirstByte) => {}
107                            Some(StreamEvent::ThinkingDelta(delta)) => {
108                                // Check cancellation before emitting
109                                if let Some(token) = &self.cancel_token
110                                    && token.is_cancelled()
111                                {
112                                    return Err(anyhow::anyhow!("Operation cancelled"));
113                                }
114                                if current_thinking.is_empty() {
115                                    self.emit(AgentEvent::thinking_start())?;
116                                }
117                                current_thinking.push_str(&delta);
118                                self.emit(AgentEvent::thinking_delta(delta, None))?;
119                            }
120                            Some(StreamEvent::TextDelta(delta)) => {
121                                // Check cancellation before emitting
122                                if let Some(token) = &self.cancel_token
123                                    && token.is_cancelled()
124                                {
125                                    return Err(anyhow::anyhow!("Operation cancelled"));
126                                }
127                                if current_text.is_empty() {
128                                    self.emit(AgentEvent::text_start())?;
129                                }
130                                current_text.push_str(&delta);
131                                self.emit(AgentEvent::text_delta(delta))?;
132                            }
133                            Some(StreamEvent::ToolUseStart { id, name }) => {
134                                // Emit events for UI but don't push content blocks
135                                // Content will be added from Done event's resp.content
136                                if !current_thinking.is_empty() {
137                                    self.emit(AgentEvent::thinking_end())?;
138                                    // Don't push - will be added from resp.content
139                                }
140                                if !current_text.is_empty() {
141                                    self.emit(AgentEvent::text_end())?;
142                                    // Don't push - will be added from resp.content
143                                }
144                                self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
145                            }
146                            Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
147                            Some(StreamEvent::Usage { output_tokens }) => {
148                                self.emit(AgentEvent::usage_with_cache(
149                                    0,
150                                    output_tokens as u64,
151                                    0,
152                                    0,
153                                ))?;
154                                usage.output_tokens = output_tokens;
155                            }
156                            Some(StreamEvent::Done(resp)) => {
157                                // Check cancellation before processing final response
158                                if let Some(token) = &self.cancel_token
159                                    && token.is_cancelled()
160                                {
161                                    return Err(anyhow::anyhow!("Operation cancelled"));
162                                }
163
164                                // Final drain of pending inputs before completing
165                                self.drain_pending_inputs();
166
167                                // Don't add current_thinking/current_text here - use resp.content directly
168                                // This avoids duplicates since resp.content contains everything
169                                // Just emit events for UI updates if we have pending content
170                                if !current_thinking.is_empty() {
171                                    self.emit(AgentEvent::thinking_end())?;
172                                    // Don't push to response_content - will be added from resp.content
173                                }
174                                if !current_text.is_empty() {
175                                    self.emit(AgentEvent::text_end())?;
176                                    // Don't push to response_content - will be added from resp.content
177                                }
178
179                                // Add all blocks from final response with smart deduplication
180                                for block in &resp.content {
181                                    // Smart deduplication: compare content, not entire block
182                                    let is_duplicate = response_content.iter().any(|b| {
183                                        match (b, block) {
184                                            // For Thinking blocks, compare thinking content only (signature may differ)
185                                            (ContentBlock::Thinking { thinking: t1, .. }, ContentBlock::Thinking { thinking: t2, .. }) => {
186                                                t1 == t2
187                                            }
188                                            // For Text blocks, compare text content
189                                            (ContentBlock::Text { text: t1 }, ContentBlock::Text { text: t2 }) => {
190                                                t1 == t2
191                                            }
192                                            // For ToolUse, compare id
193                                            (ContentBlock::ToolUse { id: id1, .. }, ContentBlock::ToolUse { id: id2, .. }) => {
194                                                id1 == id2
195                                            }
196                                            // For ToolResult, compare tool_use_id
197                                            (ContentBlock::ToolResult { tool_use_id: id1, .. }, ContentBlock::ToolResult { tool_use_id: id2, .. }) => {
198                                                id1 == id2
199                                            }
200                                            // Default: exact comparison
201                                            _ => b == block
202                                        }
203                                    });
204                                    if !is_duplicate {
205                                        response_content.push(block.clone());
206                                    }
207                                }
208                                usage = resp.usage;
209                            }
210                            Some(StreamEvent::Error(msg)) => {
211                                if attempt < MAX_RETRIES {
212                                    self.emit(AgentEvent::progress(
213                                        format!(
214                                            "⚠️ Stream error, retrying ({}/{}): {}",
215                                            attempt, MAX_RETRIES, &msg
216                                        ),
217                                        None,
218                                    ))?;
219                                    let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
220                                    tokio::time::sleep(tokio::time::Duration::from_millis(delay))
221                                        .await;
222                                    should_retry = true;
223                                    break;
224                                } else {
225                                    self.emit(AgentEvent::error(msg.clone(), None, None))?;
226                                    return Err(anyhow::anyhow!(
227                                        "Stream error after {} retries: {}",
228                                        MAX_RETRIES,
229                                        msg
230                                    ));
231                                }
232                            }
233                        }
234                    }
235
236                    if should_retry {
237                        continue;
238                    }
239
240                    return Ok(ChatResponse {
241                        content: response_content,
242                        stop_reason: StopReason::EndTurn,
243                        usage,
244                    });
245                }
246                Err(e) => {
247                    if attempt < MAX_RETRIES {
248                        let error_msg = e.to_string();
249                        self.emit(AgentEvent::progress(
250                            format!(
251                                "⚠️ API error, retrying ({}/{}): {}",
252                                attempt, MAX_RETRIES, &error_msg
253                            ),
254                            None,
255                        ))?;
256                        let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
257                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
258                    } else {
259                        return Err(anyhow::anyhow!(
260                            "API error after {} retries: {}",
261                            MAX_RETRIES,
262                            e
263                        ));
264                    }
265                }
266            }
267        }
268    }
269}