//! Agent helper functions and utilities: context inspection (`ContextInfo`),
//! usage tracking, UI event emission, todo-reminder helpers, and tool-detail
//! formatting.

3use anyhow::Result;
4use tokio::sync::mpsc;
5
6use crate::event::{AgentEvent, EventType};
7use crate::providers::{ContentBlock, MessageContent, Role, Usage};
8use crate::truncate::truncate_chars;
9
10use super::types::Agent;
11
/// Snapshot of the agent's conversation context, intended for display.
///
/// Built by `Agent::get_context_info`. Every field is an owned copy (previews
/// are truncated copies), so the snapshot does not borrow from the agent.
#[derive(Debug, Clone)]
pub struct ContextInfo {
    /// Number of messages currently in the conversation history.
    pub message_count: usize,
    /// Heuristic token estimate for the message history (not a tokenizer count).
    pub estimated_input_tokens: u64,
    /// Total input tokens accumulated from tracked usage (lifetime).
    pub total_input_tokens: u64,
    /// Total output tokens accumulated from tracked usage (lifetime).
    pub total_output_tokens: u64,
    /// Truncated preview of the system prompt.
    pub system_prompt_preview: String,
    /// Memory summary, when one is present.
    pub memory_summary: Option<String>,
    /// Truncated preview of the project overview, when one is present.
    pub project_overview_preview: Option<String>,
    /// One-line `"Role: text"` previews of the most recent messages, oldest first.
    pub recent_messages_preview: Vec<String>,
    /// Model name.
    pub model_name: String,
    /// Max tokens setting.
    pub max_tokens: u32,
}
35
36impl Agent {
37    /// Get context information for display
38    pub fn get_context_info(&self) -> ContextInfo {
39        // Estimate tokens from messages
40        let estimated_tokens = self.messages().iter()
41            .map(|m| {
42                let content = match &m.content {
43                    MessageContent::Text(t) => t.len(),
44                    MessageContent::Blocks(blocks) => {
45                        blocks.iter()
46                            .filter_map(|b| {
47                                if let ContentBlock::Text { text } = b {
48                                    Some(text.len())
49                                } else {
50                                    None
51                                }
52                            })
53                            .sum::<usize>()
54                    }
55                };
56                // Rough estimate: ~3 chars per token + 50 for metadata
57                (content / 3 + 50) as u64
58            })
59            .sum();
60
61        // System prompt preview (first 500 chars)
62        let system_preview = truncate_chars(self.system_prompt(), 500);
63
64        // Project overview preview
65        let project_preview = self.project_overview()
66            .map(|o| truncate_chars(o, 300));
67
68        // Recent messages preview (last 5 messages)
69        let recent_preview = self.messages().iter().rev().take(5).rev()
70            .map(|m| {
71                let role = match m.role {
72                    Role::User => "User",
73                    Role::Assistant => "Assistant",
74                    Role::System => "System",
75                    Role::Tool => "Tool",
76                };
77                let content_preview = match &m.content {
78                    MessageContent::Text(t) => truncate_chars(t, 100),
79                    MessageContent::Blocks(blocks) => {
80                        let text = blocks.iter()
81                            .filter_map(|b| {
82                                if let ContentBlock::Text { text } = b {
83                                    Some(text.clone())
84                                } else {
85                                    None
86                                }
87                            })
88                            .collect::<Vec<_>>()
89                            .join(" ");
90                        truncate_chars(&text, 100)
91                    }
92                };
93                format!("{}: {}", role, content_preview)
94            })
95            .collect();
96
97        ContextInfo {
98            message_count: self.messages().len(),
99            estimated_input_tokens: estimated_tokens,
100            total_input_tokens: self.state.total_input_tokens(),
101            total_output_tokens: self.state.total_output_tokens(),
102            system_prompt_preview: system_preview,
103            memory_summary: self.memory_summary().map(|s| s.to_string()),
104            project_overview_preview: project_preview,
105            recent_messages_preview: recent_preview,
106            model_name: self.model_name.clone(),
107            max_tokens: self.max_tokens(),
108        }
109    }
110
111    /// Get full context preview (everything that will be sent to LLM)
112    pub fn get_full_context_preview(&self) -> String {
113        let mut preview = String::new();
114
115        // System prompt
116        preview.push_str("=== SYSTEM PROMPT ===\n");
117        preview.push_str(self.system_prompt());
118        preview.push_str("\n\n");
119
120        // Memory summary
121        if let Some(memory) = self.memory_summary() {
122            preview.push_str("=== MEMORY SUMMARY ===\n");
123            preview.push_str(memory);
124            preview.push_str("\n\n");
125        }
126
127        // Project overview
128        if let Some(overview) = self.project_overview() {
129            preview.push_str("=== PROJECT OVERVIEW ===\n");
130            preview.push_str(overview);
131            preview.push_str("\n\n");
132        }
133
134        // Messages
135        preview.push_str("=== MESSAGES ===\n");
136        for (i, msg) in self.messages().iter().enumerate() {
137            let role = match msg.role {
138                Role::User => "User",
139                Role::Assistant => "Assistant",
140                Role::System => "System",
141                Role::Tool => "Tool",
142            };
143            preview.push_str(&format!("\n[{}] {}:\n", i + 1, role));
144
145            match &msg.content {
146                MessageContent::Text(t) => {
147                    preview.push_str(t);
148                }
149                MessageContent::Blocks(blocks) => {
150                    for block in blocks {
151                        match block {
152                            ContentBlock::Text { text } => {
153                                preview.push_str(text);
154                                preview.push_str("\n");
155                            }
156                            ContentBlock::ToolUse { name, input, .. } => {
157                                preview.push_str(&format!("[Tool: {}]\n", name));
158                                preview.push_str(&serde_json::to_string_pretty(input).unwrap_or_default());
159                                preview.push_str("\n");
160                            }
161                            ContentBlock::ToolResult { tool_use_id, content } => {
162                                preview.push_str(&format!("[Tool Result: {}]\n", tool_use_id));
163                                preview.push_str(content);
164                                preview.push_str("\n");
165                            }
166                            // Thinking blocks are NOT sent to LLM, skip them
167                            ContentBlock::Thinking { .. } => {
168                                continue;
169                            }
170                            ContentBlock::ServerToolUse { name, .. } => {
171                                preview.push_str(&format!("[Server Tool: {}]\n", name));
172                            }
173                            ContentBlock::ServerToolResult { tool_use_id, content, .. } => {
174                                preview.push_str(&format!("[Server Tool Result: {}]\n", tool_use_id));
175                                preview.push_str(content);
176                                preview.push_str("\n");
177                            }
178                            _ => {
179                                continue; // Skip other non-sendable content
180                            }
181                        }
182                    }
183                }
184            }
185        }
186
187        preview
188    }
189    /// Track token usage
190    pub(crate) fn track_usage(&self, usage: &Usage) {
191        self.state.track_usage(usage);
192
193        crate::debug::debug_log().log(
194            "usage",
195            &format!(
196                "tracked: input_tokens={}, output_tokens={}, cache_read={}, cache_created={}",
197                usage.input_tokens,
198                usage.output_tokens,
199                usage.cache_read_input_tokens,
200                usage.cache_creation_input_tokens
201            ),
202        );
203
204        let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
205            self.state.total_input_tokens(),
206            usage.output_tokens as u64,
207            usage.cache_read_input_tokens as u64,
208            usage.cache_creation_input_tokens as u64,
209        ));
210    }
211
212    /// Emit event (with retry on full)
213    ///
214    /// This method tries to send an event and retries with backoff if the channel is full.
215    /// This prevents event loss which could cause UI state inconsistency.
216    pub(crate) fn emit(&self, event: AgentEvent) -> Result<()> {
217        log::debug!("Agent emit: event_type={:?}", event.event_type);
218
219        // First try non-blocking send for performance
220        match self.event_tx.try_send(event) {
221            Ok(_) => {
222                log::debug!("Agent emit: sent successfully");
223                Ok(())
224            }
225            Err(mpsc::error::TrySendError::Full(event)) => {
226                // Channel full - for critical events, we must retry
227                let is_critical = matches!(
228                    event.event_type,
229                    EventType::Error | EventType::SessionEnded | EventType::SessionStarted
230                );
231
232                if is_critical {
233                    // Retry a few times with short delays (blocking approach)
234                    let mut retries = 3;
235                    let mut current_event = event;
236                    while retries > 0 {
237                        // Short spin wait
238                        std::thread::sleep(std::time::Duration::from_millis(10));
239                        match self.event_tx.try_send(current_event) {
240                            Ok(_) => {
241                                log::debug!("Agent emit: critical event sent after retry");
242                                return Ok(());
243                            }
244                            Err(mpsc::error::TrySendError::Full(e)) => {
245                                current_event = e;
246                                retries -= 1;
247                            }
248                            Err(mpsc::error::TrySendError::Closed(_)) => {
249                                log::error!("Agent emit: channel closed during retry");
250                                return Err(anyhow::anyhow!("Event channel closed"));
251                            }
252                        }
253                    }
254                    log::warn!("Agent emit: critical event dropped after {} retries", 3);
255                    Err(anyhow::anyhow!("Event channel full, critical event dropped"))
256                } else {
257                    log::warn!("Agent emit: channel full, skipping non-critical event");
258                    Ok(())
259                }
260            }
261            Err(mpsc::error::TrySendError::Closed(_)) => {
262                log::error!("Agent emit: channel closed");
263                Err(anyhow::anyhow!("Event channel closed"))
264            }
265        }
266    }
267
268    /// Get pending (uncompleted) todos from the most recent todo_write.
269    /// Returns list of (status, content) for non-completed tasks that haven't exceeded reminder limit.
270    /// Note: todo_write replaces the entire list each time, so only the last one matters.
271    /// 
272    /// # Arguments
273    /// * `todo_reminder_count` - Reference to the reminder counter map (immutable, will be cloned inside)
274    /// * `max_reminders` - Maximum number of reminders allowed per todo (default: 2)
275    /// 
276    /// # Returns
277    /// Tuple of (pending todos, whether all todos are at reminder limit)
278    pub(crate) fn get_pending_todos_with_limit(
279        &self,
280        todo_reminder_count: &std::collections::HashMap<String, usize>,
281        max_reminders: usize,
282    ) -> (Vec<(String, String)>, bool) {
283        // Find the most recent todo_write (current state)
284        for msg in self.messages().iter().rev().take(10) {
285            if let MessageContent::Blocks(blocks) = &msg.content {
286                for block in blocks {
287                    if let ContentBlock::ToolUse { name, input, .. } = block
288                        && name == "todo_write"
289                    {
290                        // Extract non-completed todos from this todo_write
291                        if let Some(todos) = input.get("todos").and_then(|t| t.as_array()) {
292                            let pending: Vec<(String, String)> = todos
293                                .iter()
294                                .filter_map(|todo| {
295                                    let status = todo.get("status").and_then(|s| s.as_str())?;
296                                    let content = todo.get("content").and_then(|c| c.as_str())?;
297                                    if status != "completed" {
298                                        Some((status.to_string(), content.to_string()))
299                                    } else {
300                                        None
301                                    }
302                                })
303                                .collect();
304                            
305                            // Check which todos are at the reminder limit
306                            let mut filtered_pending = Vec::new();
307                            let mut all_at_limit = true;
308                            
309                            for (status, content) in pending {
310                                let count = todo_reminder_count.get(&content).copied().unwrap_or(0);
311                                if count < max_reminders {
312                                    filtered_pending.push((status, content));
313                                    all_at_limit = false;
314                                }
315                            }
316                            
317                            return (filtered_pending, all_at_limit); // Return immediately - this is the current state
318                        }
319                    }
320                }
321            }
322        }
323        (Vec::new(), true)
324    }
325    
326    /// Check if the last user message was a todo reminder.
327    /// This prevents adding duplicate reminders in consecutive iterations.
328    pub(crate) fn last_message_was_todo_reminder(&self) -> bool {
329        // Check the last few messages for a todo reminder
330        for msg in self.messages().iter().rev().take(3) {
331            if msg.role == Role::User {
332                if let MessageContent::Text(text) = &msg.content {
333                    if text.contains("任务尚未完成") && text.contains("待办项需要处理") {
334                        return true;
335                    }
336                }
337            }
338        }
339        false
340    }
341}
342
343/// Extract tool detail for display
344pub(crate) fn extract_tool_detail(tool_name: &str, input: &serde_json::Value) -> Option<String> {
345    match tool_name.to_lowercase().as_str() {
346        "read" => input
347            .get("path")
348            .and_then(|v| v.as_str())
349            .map(|s| truncate_str(s, 50)),
350        "write" => input
351            .get("path")
352            .and_then(|v| v.as_str())
353            .map(|s| truncate_str(s, 50)),
354        "edit" | "multi_edit" => {
355            let path = input.get("path").and_then(|v| v.as_str());
356            let old = input.get("old_string").and_then(|v| v.as_str());
357            match (path, old) {
358                (Some(p), Some(o)) => Some(format!(
359                    "{}: \"{}\"",
360                    truncate_str(p, 30),
361                    truncate_str(o, 20)
362                )),
363                (Some(p), None) => Some(truncate_str(p, 50)),
364                _ => None,
365            }
366        }
367        "bash" => input
368            .get("command")
369            .and_then(|v| v.as_str())
370            .map(|s| truncate_str(s, 60)),
371        "search" | "grep" => input
372            .get("pattern")
373            .and_then(|v| v.as_str())
374            .map(|s| format!("\"{}\"", truncate_str(s, 30))),
375        "glob" => input
376            .get("pattern")
377            .and_then(|v| v.as_str())
378            .map(|s| truncate_str(s, 40)),
379        "ls" => input
380            .get("path")
381            .and_then(|v| v.as_str())
382            .map(|s| truncate_str(s, 50)),
383        "websearch" => input
384            .get("query")
385            .and_then(|v| v.as_str())
386            .map(|s| truncate_str(s, 40)),
387        "webfetch" => input
388            .get("url")
389            .and_then(|v| v.as_str())
390            .map(|s| truncate_str(s, 50)),
391        "task" => input
392            .get("description")
393            .and_then(|v| v.as_str())
394            .map(|s| truncate_str(s, 40)),
395        "task_create" => input
396            .get("description")
397            .and_then(|v| v.as_str())
398            .map(|s| truncate_str(s, 40)),
399        "task_get" | "task_stop" => input
400            .get("task_id")
401            .and_then(|v| v.as_str())
402            .map(|s| s.to_string()),
403        _ => None,
404    }
405}
406
407/// Truncate string at char boundary (using character count, not bytes)
408pub(crate) fn truncate_str(s: &str, max: usize) -> String {
409    truncate_chars(s, max)
410}