// matrixcode_core/agent/helpers.rs

1//! Agent helper functions and utilities.
2
3use anyhow::Result;
4use std::sync::atomic::Ordering;
5use tokio::sync::mpsc;
6
7use crate::event::AgentEvent;
8use crate::providers::{ContentBlock, MessageContent, Role, Usage};
9use crate::truncate::truncate_chars;
10
11use super::types::Agent;
12
/// Snapshot of an agent's conversation context, prepared for display.
///
/// Every field is an owned value, so a `ContextInfo` can be cloned, logged with
/// `{:?}` and handed to UI code without borrowing from the agent.
#[derive(Debug, Clone)]
pub struct ContextInfo {
    /// Number of messages currently in the conversation.
    pub message_count: usize,
    /// Rough estimate of the input tokens the current messages occupy.
    pub estimated_input_tokens: u64,
    /// Total input tokens (lifetime).
    pub total_input_tokens: u64,
    /// Total output tokens (lifetime).
    pub total_output_tokens: u64,
    /// Shortened preview of the system prompt.
    pub system_prompt_preview: String,
    /// Memory summary, if one is available.
    pub memory_summary: Option<String>,
    /// Shortened preview of the project overview, if one is available.
    pub project_overview_preview: Option<String>,
    /// One `"<Role>: <text>"` preview line per recent message.
    pub recent_messages_preview: Vec<String>,
    /// Model name.
    pub model_name: String,
    /// Max tokens setting.
    pub max_tokens: u32,
}
36
37impl Agent {
38    /// Get context information for display
39    pub fn get_context_info(&self) -> ContextInfo {
40        // Estimate tokens from messages
41        let estimated_tokens = self.messages.iter()
42            .map(|m| {
43                let content = match &m.content {
44                    MessageContent::Text(t) => t.len(),
45                    MessageContent::Blocks(blocks) => {
46                        blocks.iter()
47                            .filter_map(|b| {
48                                if let ContentBlock::Text { text } = b {
49                                    Some(text.len())
50                                } else {
51                                    None
52                                }
53                            })
54                            .sum::<usize>()
55                    }
56                };
57                // Rough estimate: ~3 chars per token + 50 for metadata
58                (content / 3 + 50) as u64
59            })
60            .sum();
61
62        // System prompt preview (first 500 chars)
63        let system_preview = truncate_chars(&self.system_prompt, 500);
64
65        // Project overview preview
66        let project_preview = self.project_overview.as_ref()
67            .map(|o| truncate_chars(o, 300));
68
69        // Recent messages preview (last 5 messages)
70        let recent_preview = self.messages.iter().rev().take(5).rev()
71            .map(|m| {
72                let role = match m.role {
73                    Role::User => "User",
74                    Role::Assistant => "Assistant",
75                    Role::System => "System",
76                    Role::Tool => "Tool",
77                };
78                let content_preview = match &m.content {
79                    MessageContent::Text(t) => truncate_chars(t, 100),
80                    MessageContent::Blocks(blocks) => {
81                        let text = blocks.iter()
82                            .filter_map(|b| {
83                                if let ContentBlock::Text { text } = b {
84                                    Some(text.clone())
85                                } else {
86                                    None
87                                }
88                            })
89                            .collect::<Vec<_>>()
90                            .join(" ");
91                        truncate_chars(&text, 100)
92                    }
93                };
94                format!("{}: {}", role, content_preview)
95            })
96            .collect();
97
98        ContextInfo {
99            message_count: self.messages.len(),
100            estimated_input_tokens: estimated_tokens,
101            total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
102            total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
103            system_prompt_preview: system_preview,
104            memory_summary: self.memory_summary.clone(),
105            project_overview_preview: project_preview,
106            recent_messages_preview: recent_preview,
107            model_name: self.model_name.clone(),
108            max_tokens: self.max_tokens,
109        }
110    }
111
112    /// Get full context preview (everything that will be sent to LLM)
113    pub fn get_full_context_preview(&self) -> String {
114        let mut preview = String::new();
115
116        // System prompt
117        preview.push_str("=== SYSTEM PROMPT ===\n");
118        preview.push_str(&self.system_prompt);
119        preview.push_str("\n\n");
120
121        // Memory summary
122        if let Some(memory) = &self.memory_summary {
123            preview.push_str("=== MEMORY SUMMARY ===\n");
124            preview.push_str(memory);
125            preview.push_str("\n\n");
126        }
127
128        // Project overview
129        if let Some(overview) = &self.project_overview {
130            preview.push_str("=== PROJECT OVERVIEW ===\n");
131            preview.push_str(overview);
132            preview.push_str("\n\n");
133        }
134
135        // Messages
136        preview.push_str("=== MESSAGES ===\n");
137        for (i, msg) in self.messages.iter().enumerate() {
138            let role = match msg.role {
139                Role::User => "User",
140                Role::Assistant => "Assistant",
141                Role::System => "System",
142                Role::Tool => "Tool",
143            };
144            preview.push_str(&format!("\n[{}] {}:\n", i + 1, role));
145
146            match &msg.content {
147                MessageContent::Text(t) => {
148                    preview.push_str(t);
149                }
150                MessageContent::Blocks(blocks) => {
151                    for block in blocks {
152                        match block {
153                            ContentBlock::Text { text } => {
154                                preview.push_str(text);
155                                preview.push_str("\n");
156                            }
157                            ContentBlock::ToolUse { name, input, .. } => {
158                                preview.push_str(&format!("[Tool: {}]\n", name));
159                                preview.push_str(&serde_json::to_string_pretty(input).unwrap_or_default());
160                                preview.push_str("\n");
161                            }
162                            ContentBlock::ToolResult { tool_use_id, content } => {
163                                preview.push_str(&format!("[Tool Result: {}]\n", tool_use_id));
164                                preview.push_str(content);
165                                preview.push_str("\n");
166                            }
167                            // Handle other block types
168                            ContentBlock::Thinking { thinking, .. } => {
169                                preview.push_str("[Thinking]\n");
170                                preview.push_str(thinking);
171                                preview.push_str("\n");
172                            }
173                            ContentBlock::ServerToolUse { name, .. } => {
174                                preview.push_str(&format!("[Server Tool: {}]\n", name));
175                            }
176                            ContentBlock::ServerToolResult { tool_use_id, content, .. } => {
177                                preview.push_str(&format!("[Server Tool Result: {}]\n", tool_use_id));
178                                preview.push_str(content);
179                                preview.push_str("\n");
180                            }
181                            _ => {
182                                preview.push_str("[Other Content]\n");
183                            }
184                        }
185                    }
186                }
187            }
188        }
189
190        preview
191    }
192    /// Track token usage
193    pub(crate) fn track_usage(&self, usage: &Usage) {
194        self.total_input_tokens
195            .fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
196        self.total_output_tokens
197            .fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
198        self.last_input_tokens
199            .store(usage.input_tokens as u64, Ordering::Relaxed);
200
201        crate::debug::debug_log().log(
202            "usage",
203            &format!(
204                "tracked: input_tokens={}, output_tokens={}, cache_read={}, cache_created={}",
205                usage.input_tokens,
206                usage.output_tokens,
207                usage.cache_read_input_tokens,
208                usage.cache_creation_input_tokens
209            ),
210        );
211
212        let _ = self.event_tx.try_send(AgentEvent::usage_with_cache(
213            self.total_input_tokens.load(Ordering::Relaxed),
214            usage.output_tokens as u64,
215            usage.cache_read_input_tokens as u64,
216            usage.cache_creation_input_tokens as u64,
217        ));
218    }
219
220    /// Emit event (non-blocking)
221    pub(crate) fn emit(&self, event: AgentEvent) -> Result<()> {
222        log::debug!("Agent emit: event_type={:?}", event.event_type);
223        match self.event_tx.try_send(event) {
224            Ok(_) => {
225                log::debug!("Agent emit: sent successfully");
226                Ok(())
227            }
228            Err(mpsc::error::TrySendError::Full(_)) => {
229                log::warn!("Agent emit: channel full, skipping event");
230                Ok(())
231            }
232            Err(mpsc::error::TrySendError::Closed(_)) => {
233                log::error!("Agent emit: channel closed");
234                Err(anyhow::anyhow!("Event channel closed"))
235            }
236        }
237    }
238
239    /// Get pending (uncompleted) todos from the most recent todo_write.
240    /// Returns list of (status, content) for non-completed tasks that haven't exceeded reminder limit.
241    /// Note: todo_write replaces the entire list each time, so only the last one matters.
242    /// 
243    /// # Arguments
244    /// * `todo_reminder_count` - Reference to the reminder counter map (immutable, will be cloned inside)
245    /// * `max_reminders` - Maximum number of reminders allowed per todo (default: 2)
246    /// 
247    /// # Returns
248    /// Tuple of (pending todos, whether all todos are at reminder limit)
249    pub(crate) fn get_pending_todos_with_limit(
250        &self,
251        todo_reminder_count: &std::collections::HashMap<String, usize>,
252        max_reminders: usize,
253    ) -> (Vec<(String, String)>, bool) {
254        // Find the most recent todo_write (current state)
255        for msg in self.messages.iter().rev().take(10) {
256            if let MessageContent::Blocks(blocks) = &msg.content {
257                for block in blocks {
258                    if let ContentBlock::ToolUse { name, input, .. } = block
259                        && name == "todo_write"
260                    {
261                        // Extract non-completed todos from this todo_write
262                        if let Some(todos) = input.get("todos").and_then(|t| t.as_array()) {
263                            let pending: Vec<(String, String)> = todos
264                                .iter()
265                                .filter_map(|todo| {
266                                    let status = todo.get("status").and_then(|s| s.as_str())?;
267                                    let content = todo.get("content").and_then(|c| c.as_str())?;
268                                    if status != "completed" {
269                                        Some((status.to_string(), content.to_string()))
270                                    } else {
271                                        None
272                                    }
273                                })
274                                .collect();
275                            
276                            // Check which todos are at the reminder limit
277                            let mut filtered_pending = Vec::new();
278                            let mut all_at_limit = true;
279                            
280                            for (status, content) in pending {
281                                let count = todo_reminder_count.get(&content).copied().unwrap_or(0);
282                                if count < max_reminders {
283                                    filtered_pending.push((status, content));
284                                    all_at_limit = false;
285                                }
286                            }
287                            
288                            return (filtered_pending, all_at_limit); // Return immediately - this is the current state
289                        }
290                    }
291                }
292            }
293        }
294        (Vec::new(), true)
295    }
296    
297    /// Check if the last user message was a todo reminder.
298    /// This prevents adding duplicate reminders in consecutive iterations.
299    pub(crate) fn last_message_was_todo_reminder(&self) -> bool {
300        // Check the last few messages for a todo reminder
301        for msg in self.messages.iter().rev().take(3) {
302            if msg.role == Role::User {
303                if let MessageContent::Text(text) = &msg.content {
304                    if text.contains("任务尚未完成") && text.contains("待办项需要处理") {
305                        return true;
306                    }
307                }
308            }
309        }
310        false
311    }
312}
313
314/// Extract tool detail for display
315pub(crate) fn extract_tool_detail(tool_name: &str, input: &serde_json::Value) -> Option<String> {
316    match tool_name.to_lowercase().as_str() {
317        "read" => input
318            .get("path")
319            .and_then(|v| v.as_str())
320            .map(|s| truncate_str(s, 50)),
321        "write" => input
322            .get("path")
323            .and_then(|v| v.as_str())
324            .map(|s| truncate_str(s, 50)),
325        "edit" | "multi_edit" => {
326            let path = input.get("path").and_then(|v| v.as_str());
327            let old = input.get("old_string").and_then(|v| v.as_str());
328            match (path, old) {
329                (Some(p), Some(o)) => Some(format!(
330                    "{}: \"{}\"",
331                    truncate_str(p, 30),
332                    truncate_str(o, 20)
333                )),
334                (Some(p), None) => Some(truncate_str(p, 50)),
335                _ => None,
336            }
337        }
338        "bash" => input
339            .get("command")
340            .and_then(|v| v.as_str())
341            .map(|s| truncate_str(s, 60)),
342        "search" | "grep" => input
343            .get("pattern")
344            .and_then(|v| v.as_str())
345            .map(|s| format!("\"{}\"", truncate_str(s, 30))),
346        "glob" => input
347            .get("pattern")
348            .and_then(|v| v.as_str())
349            .map(|s| truncate_str(s, 40)),
350        "ls" => input
351            .get("path")
352            .and_then(|v| v.as_str())
353            .map(|s| truncate_str(s, 50)),
354        "websearch" => input
355            .get("query")
356            .and_then(|v| v.as_str())
357            .map(|s| truncate_str(s, 40)),
358        "webfetch" => input
359            .get("url")
360            .and_then(|v| v.as_str())
361            .map(|s| truncate_str(s, 50)),
362        "task" => input
363            .get("description")
364            .and_then(|v| v.as_str())
365            .map(|s| truncate_str(s, 40)),
366        "task_create" => input
367            .get("description")
368            .and_then(|v| v.as_str())
369            .map(|s| truncate_str(s, 40)),
370        "task_get" | "task_stop" => input
371            .get("task_id")
372            .and_then(|v| v.as_str())
373            .map(|s| s.to_string()),
374        _ => None,
375    }
376}
377
378/// Truncate string at char boundary (using character count, not bytes)
379pub(crate) fn truncate_str(s: &str, max: usize) -> String {
380    truncate_chars(s, max)
381}