Skip to main content

matrixcode_core/agent/
helpers.rs

1//! Agent helper functions and utilities.
2
3use anyhow::Result;
4use std::sync::atomic::Ordering;
5use tokio::sync::mpsc;
6
7use crate::event::AgentEvent;
8use crate::providers::{ContentBlock, MessageContent, Role, Usage};
9use crate::truncate::truncate_chars;
10
11use super::types::Agent;
12
/// Snapshot of the agent's context, assembled by `Agent::get_context_info` for display.
///
/// Derives `Debug`/`Clone`/`PartialEq` so UI and test code can log, copy and compare it, and
/// `Default` so partially-filled values can be built with struct-update syntax.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContextInfo {
    /// Number of messages currently in the conversation history
    pub message_count: usize,
    /// Heuristic estimate of the tokens occupied by the conversation messages
    pub estimated_input_tokens: u64,
    /// Total input tokens reported by the provider (lifetime)
    pub total_input_tokens: u64,
    /// Total output tokens reported by the provider (lifetime)
    pub total_output_tokens: u64,
    /// Truncated preview of the system prompt
    pub system_prompt_preview: String,
    /// Memory summary, if one exists
    pub memory_summary: Option<String>,
    /// Truncated preview of the project overview, if one exists
    pub project_overview_preview: Option<String>,
    /// One-line previews ("Role: text") of the most recent messages
    pub recent_messages_preview: Vec<String>,
    /// Model name
    pub model_name: String,
    /// Max tokens setting
    pub max_tokens: u32,
}
36
37impl Agent {
38    /// Get context information for display
39    pub fn get_context_info(&self) -> ContextInfo {
40        // Estimate tokens from messages
41        let estimated_tokens = self.messages.iter()
42            .map(|m| {
43                let content = match &m.content {
44                    MessageContent::Text(t) => t.len(),
45                    MessageContent::Blocks(blocks) => {
46                        blocks.iter()
47                            .filter_map(|b| {
48                                if let ContentBlock::Text { text } = b {
49                                    Some(text.len())
50                                } else {
51                                    None
52                                }
53                            })
54                            .sum::<usize>()
55                    }
56                };
57                // Rough estimate: ~3 chars per token + 50 for metadata
58                (content / 3 + 50) as u64
59            })
60            .sum();
61
62        // System prompt preview (first 500 chars)
63        let system_preview = truncate_chars(&self.system_prompt, 500);
64
65        // Project overview preview
66        let project_preview = self.project_overview.as_ref()
67            .map(|o| truncate_chars(o, 300));
68
69        // Recent messages preview (last 5 messages)
70        let recent_preview = self.messages.iter().rev().take(5).rev()
71            .map(|m| {
72                let role = match m.role {
73                    Role::User => "User",
74                    Role::Assistant => "Assistant",
75                    Role::System => "System",
76                    Role::Tool => "Tool",
77                };
78                let content_preview = match &m.content {
79                    MessageContent::Text(t) => truncate_chars(t, 100),
80                    MessageContent::Blocks(blocks) => {
81                        let text = blocks.iter()
82                            .filter_map(|b| {
83                                if let ContentBlock::Text { text } = b {
84                                    Some(text.clone())
85                                } else {
86                                    None
87                                }
88                            })
89                            .collect::<Vec<_>>()
90                            .join(" ");
91                        truncate_chars(&text, 100)
92                    }
93                };
94                format!("{}: {}", role, content_preview)
95            })
96            .collect();
97
98        ContextInfo {
99            message_count: self.messages.len(),
100            estimated_input_tokens: estimated_tokens,
101            total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
102            total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
103            system_prompt_preview: system_preview,
104            memory_summary: self.memory_summary.clone(),
105            project_overview_preview: project_preview,
106            recent_messages_preview: recent_preview,
107            model_name: self.model_name.clone(),
108            max_tokens: self.max_tokens,
109        }
110    }
111
112    /// Get full context preview (everything that will be sent to LLM)
113    pub fn get_full_context_preview(&self) -> String {
114        let mut preview = String::new();
115
116        // System prompt
117        preview.push_str("=== SYSTEM PROMPT ===\n");
118        preview.push_str(&self.system_prompt);
119        preview.push_str("\n\n");
120
121        // Memory summary
122        if let Some(memory) = &self.memory_summary {
123            preview.push_str("=== MEMORY SUMMARY ===\n");
124            preview.push_str(memory);
125            preview.push_str("\n\n");
126        }
127
128        // Project overview
129        if let Some(overview) = &self.project_overview {
130            preview.push_str("=== PROJECT OVERVIEW ===\n");
131            preview.push_str(overview);
132            preview.push_str("\n\n");
133        }
134
135        // Messages
136        preview.push_str("=== MESSAGES ===\n");
137        for (i, msg) in self.messages.iter().enumerate() {
138            let role = match msg.role {
139                Role::User => "User",
140                Role::Assistant => "Assistant",
141                Role::System => "System",
142                Role::Tool => "Tool",
143            };
144            preview.push_str(&format!("\n[{}] {}:\n", i + 1, role));
145
146            match &msg.content {
147                MessageContent::Text(t) => {
148                    preview.push_str(t);
149                }
150                MessageContent::Blocks(blocks) => {
151                    for block in blocks {
152                        match block {
153                            ContentBlock::Text { text } => {
154                                preview.push_str(text);
155                                preview.push_str("\n");
156                            }
157                            ContentBlock::ToolUse { name, input, .. } => {
158                                preview.push_str(&format!("[Tool: {}]\n", name));
159                                preview.push_str(&serde_json::to_string_pretty(input).unwrap_or_default());
160                                preview.push_str("\n");
161                            }
162                            ContentBlock::ToolResult { tool_use_id, content } => {
163                                preview.push_str(&format!("[Tool Result: {}]\n", tool_use_id));
164                                preview.push_str(content);
165                                preview.push_str("\n");
166                            }
167                            // Thinking blocks are NOT sent to LLM, skip them
168                            ContentBlock::Thinking { .. } => {
169                                continue;
170                            }
171                            ContentBlock::ServerToolUse { name, .. } => {
172                                preview.push_str(&format!("[Server Tool: {}]\n", name));
173                            }
174                            ContentBlock::ServerToolResult { tool_use_id, content, .. } => {
175                                preview.push_str(&format!("[Server Tool Result: {}]\n", tool_use_id));
176                                preview.push_str(content);
177                                preview.push_str("\n");
178                            }
179                            _ => {
180                                continue; // Skip other non-sendable content
181                            }
182                        }
183                    }
184                }
185            }
186        }
187
188        preview
189    }
190    /// Track token usage
191    pub(crate) fn track_usage(&self, usage: &Usage) {
192        self.total_input_tokens
193            .fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
194        self.total_output_tokens
195            .fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
196        self.last_input_tokens
197            .store(usage.input_tokens as u64, Ordering::Relaxed);
198
199        crate::debug::debug_log().log(
200            "usage",
201            &format!(
202                "tracked: input_tokens={}, output_tokens={}, cache_read={}, cache_created={}",
203                usage.input_tokens,
204                usage.output_tokens,
205                usage.cache_read_input_tokens,
206                usage.cache_creation_input_tokens
207            ),
208        );
209
210        let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
211            self.total_input_tokens.load(Ordering::Relaxed),
212            usage.output_tokens as u64,
213            usage.cache_read_input_tokens as u64,
214            usage.cache_creation_input_tokens as u64,
215        ));
216    }
217
218    /// Emit event (non-blocking)
219    pub(crate) fn emit(&self, event: AgentEvent) -> Result<()> {
220        log::debug!("Agent emit: event_type={:?}", event.event_type);
221        match self.event_tx.try_send(event) {
222            Ok(_) => {
223                log::debug!("Agent emit: sent successfully");
224                Ok(())
225            }
226            Err(mpsc::error::TrySendError::Full(_)) => {
227                log::warn!("Agent emit: channel full, skipping event");
228                Ok(())
229            }
230            Err(mpsc::error::TrySendError::Closed(_)) => {
231                log::error!("Agent emit: channel closed");
232                Err(anyhow::anyhow!("Event channel closed"))
233            }
234        }
235    }
236
237    /// Get pending (uncompleted) todos from the most recent todo_write.
238    /// Returns list of (status, content) for non-completed tasks that haven't exceeded reminder limit.
239    /// Note: todo_write replaces the entire list each time, so only the last one matters.
240    /// 
241    /// # Arguments
242    /// * `todo_reminder_count` - Reference to the reminder counter map (immutable, will be cloned inside)
243    /// * `max_reminders` - Maximum number of reminders allowed per todo (default: 2)
244    /// 
245    /// # Returns
246    /// Tuple of (pending todos, whether all todos are at reminder limit)
247    pub(crate) fn get_pending_todos_with_limit(
248        &self,
249        todo_reminder_count: &std::collections::HashMap<String, usize>,
250        max_reminders: usize,
251    ) -> (Vec<(String, String)>, bool) {
252        // Find the most recent todo_write (current state)
253        for msg in self.messages.iter().rev().take(10) {
254            if let MessageContent::Blocks(blocks) = &msg.content {
255                for block in blocks {
256                    if let ContentBlock::ToolUse { name, input, .. } = block
257                        && name == "todo_write"
258                    {
259                        // Extract non-completed todos from this todo_write
260                        if let Some(todos) = input.get("todos").and_then(|t| t.as_array()) {
261                            let pending: Vec<(String, String)> = todos
262                                .iter()
263                                .filter_map(|todo| {
264                                    let status = todo.get("status").and_then(|s| s.as_str())?;
265                                    let content = todo.get("content").and_then(|c| c.as_str())?;
266                                    if status != "completed" {
267                                        Some((status.to_string(), content.to_string()))
268                                    } else {
269                                        None
270                                    }
271                                })
272                                .collect();
273                            
274                            // Check which todos are at the reminder limit
275                            let mut filtered_pending = Vec::new();
276                            let mut all_at_limit = true;
277                            
278                            for (status, content) in pending {
279                                let count = todo_reminder_count.get(&content).copied().unwrap_or(0);
280                                if count < max_reminders {
281                                    filtered_pending.push((status, content));
282                                    all_at_limit = false;
283                                }
284                            }
285                            
286                            return (filtered_pending, all_at_limit); // Return immediately - this is the current state
287                        }
288                    }
289                }
290            }
291        }
292        (Vec::new(), true)
293    }
294    
295    /// Check if the last user message was a todo reminder.
296    /// This prevents adding duplicate reminders in consecutive iterations.
297    pub(crate) fn last_message_was_todo_reminder(&self) -> bool {
298        // Check the last few messages for a todo reminder
299        for msg in self.messages.iter().rev().take(3) {
300            if msg.role == Role::User {
301                if let MessageContent::Text(text) = &msg.content {
302                    if text.contains("任务尚未完成") && text.contains("待办项需要处理") {
303                        return true;
304                    }
305                }
306            }
307        }
308        false
309    }
310}
311
312/// Extract tool detail for display
313pub(crate) fn extract_tool_detail(tool_name: &str, input: &serde_json::Value) -> Option<String> {
314    match tool_name.to_lowercase().as_str() {
315        "read" => input
316            .get("path")
317            .and_then(|v| v.as_str())
318            .map(|s| truncate_str(s, 50)),
319        "write" => input
320            .get("path")
321            .and_then(|v| v.as_str())
322            .map(|s| truncate_str(s, 50)),
323        "edit" | "multi_edit" => {
324            let path = input.get("path").and_then(|v| v.as_str());
325            let old = input.get("old_string").and_then(|v| v.as_str());
326            match (path, old) {
327                (Some(p), Some(o)) => Some(format!(
328                    "{}: \"{}\"",
329                    truncate_str(p, 30),
330                    truncate_str(o, 20)
331                )),
332                (Some(p), None) => Some(truncate_str(p, 50)),
333                _ => None,
334            }
335        }
336        "bash" => input
337            .get("command")
338            .and_then(|v| v.as_str())
339            .map(|s| truncate_str(s, 60)),
340        "search" | "grep" => input
341            .get("pattern")
342            .and_then(|v| v.as_str())
343            .map(|s| format!("\"{}\"", truncate_str(s, 30))),
344        "glob" => input
345            .get("pattern")
346            .and_then(|v| v.as_str())
347            .map(|s| truncate_str(s, 40)),
348        "ls" => input
349            .get("path")
350            .and_then(|v| v.as_str())
351            .map(|s| truncate_str(s, 50)),
352        "websearch" => input
353            .get("query")
354            .and_then(|v| v.as_str())
355            .map(|s| truncate_str(s, 40)),
356        "webfetch" => input
357            .get("url")
358            .and_then(|v| v.as_str())
359            .map(|s| truncate_str(s, 50)),
360        "task" => input
361            .get("description")
362            .and_then(|v| v.as_str())
363            .map(|s| truncate_str(s, 40)),
364        "task_create" => input
365            .get("description")
366            .and_then(|v| v.as_str())
367            .map(|s| truncate_str(s, 40)),
368        "task_get" | "task_stop" => input
369            .get("task_id")
370            .and_then(|v| v.as_str())
371            .map(|s| s.to_string()),
372        _ => None,
373    }
374}
375
376/// Truncate string at char boundary (using character count, not bytes)
377pub(crate) fn truncate_str(s: &str, max: usize) -> String {
378    truncate_chars(s, max)
379}