use crate::command::chat::storage::{ChatMessage, MessageRole};
use std::collections::HashMap;
/// Number of most-recent tool-call broadcasts to keep verbatim per other agent.
///
/// NOTE(review): presumably intended as the `threshold` argument of
/// [`compress_other_agent_toolcalls`]; nothing in this file section references it — confirm
/// against callers.
pub const DEFAULT_OTHER_AGENT_TOOLCALL_THRESHOLD: usize = 5;
/// Splits a leading `<AgentName>` tag off `content`.
///
/// Leading whitespace before the `<` is skipped. On success returns the text between the
/// `<` and the first `>` as an owned agent name, together with everything after that `>`
/// (not trimmed). Returns `None` when the left-trimmed content does not start with `<` or
/// contains no `>`.
fn extract_agent_source(content: &str) -> Option<(String, &str)> {
    let after_open = content.trim_start().strip_prefix('<')?;
    let (name, rest) = after_open.split_once('>')?;
    Some((name.to_owned(), rest))
}
/// Recognises a tool-call broadcast of the form `<AgentName> [调用工具 ToolName]...`.
///
/// Whitespace between the `<AgentName>` tag and the `[` is allowed. The tool name is the
/// text between the `"[调用工具 "` marker and the first following `]`; anything after that
/// `]` is ignored. Returns `Some((agent_name, tool_name))` on a match, `None` otherwise.
fn is_tool_call_broadcast(content: &str) -> Option<(String, String)> {
    const TOOL_CALL_MARKER: &str = "[调用工具 ";

    let (agent, body) = extract_agent_source(content)?;
    let after_marker = body.trim_start().strip_prefix(TOOL_CALL_MARKER)?;
    let (tool, _rest) = after_marker.split_once(']')?;
    Some((agent, tool.to_owned()))
}
/// Collapses older tool-call broadcasts from *other* agents into one summary per agent.
///
/// A tool-call broadcast is a `User`-role message whose content looks like
/// `<AgentName> [调用工具 ToolName]`. Broadcasts from `self_agent_name`, and messages with
/// any other role, are never touched.
///
/// For each other agent with more than `threshold` broadcasts, the latest `threshold`
/// (by position in `messages`) are kept verbatim and the earlier ones are removed. A
/// single summary message `<AgentName> [早期工具调用摘要: Tool×count, ..., 共 N 次]` is
/// inserted where the first removed broadcast used to be. Tools in the summary are listed
/// in the order they were first called, so the same input always yields the same output.
/// All other messages keep their relative order.
///
/// # Arguments
/// * `messages` - the conversation history (assumed oldest first, since later positions
///   are the ones kept as "recent").
/// * `self_agent_name` - the agent whose own broadcasts are left alone.
/// * `threshold` - number of recent broadcasts to keep per other agent; `0` disables
///   compression and returns the input unchanged.
///
/// # Returns
/// A new message list; `messages` itself is not modified.
pub fn compress_other_agent_toolcalls(
    messages: &[ChatMessage],
    self_agent_name: &str,
    threshold: usize,
) -> Vec<ChatMessage> {
    if messages.is_empty() || threshold == 0 {
        return messages.to_vec();
    }

    let agent_groups = group_other_agent_toolcalls(messages, self_agent_name);
    if agent_groups.is_empty() {
        return messages.to_vec();
    }

    // `compressed[i]` marks messages to drop. A flat mask keeps the final pass O(n)
    // instead of scanning a list of indices once per message.
    let mut compressed = vec![false; messages.len()];
    // Summary text, keyed by the index of the first broadcast it replaces.
    let mut summary_by_first_idx: HashMap<usize, String> = HashMap::new();

    for (agent_name, calls) in &agent_groups {
        if calls.len() <= threshold {
            continue;
        }
        // Keep the newest `threshold` calls; everything before them is folded.
        let (to_compress, _to_keep) = calls.split_at(calls.len() - threshold);
        for (idx, _) in to_compress {
            compressed[*idx] = true;
        }
        if let Some((first_idx, _)) = to_compress.first() {
            summary_by_first_idx.insert(
                *first_idx,
                format_toolcall_summary(agent_name, to_compress),
            );
        }
    }

    let mut result = Vec::with_capacity(messages.len());
    for (idx, msg) in messages.iter().enumerate() {
        if let Some(summary) = summary_by_first_idx.remove(&idx) {
            result.push(ChatMessage::text(MessageRole::User, summary));
        }
        if !compressed[idx] {
            result.push(msg.clone());
        }
    }
    result
}

/// Groups tool-call broadcasts from agents other than `self_agent_name` by agent name.
///
/// Only `User`-role messages are considered. Each group lists `(message_index, tool_name)`
/// in ascending index order.
fn group_other_agent_toolcalls(
    messages: &[ChatMessage],
    self_agent_name: &str,
) -> HashMap<String, Vec<(usize, String)>> {
    let mut groups: HashMap<String, Vec<(usize, String)>> = HashMap::new();
    for (idx, msg) in messages.iter().enumerate() {
        if msg.role != MessageRole::User {
            continue;
        }
        if let Some((agent_name, tool_name)) = is_tool_call_broadcast(&msg.content) {
            if agent_name != self_agent_name {
                groups.entry(agent_name).or_default().push((idx, tool_name));
            }
        }
    }
    groups
}

/// Builds the summary text for one agent's folded tool calls.
///
/// Per-tool counts are listed in first-call order. A `HashMap` is deliberately avoided
/// here because its iteration order is randomized per map, which would make identical
/// histories produce different summaries.
fn format_toolcall_summary(agent_name: &str, calls: &[(usize, String)]) -> String {
    let mut tool_counts: Vec<(&str, usize)> = Vec::new();
    for (_, tool) in calls {
        match tool_counts.iter_mut().find(|(name, _)| *name == tool.as_str()) {
            Some((_, count)) => *count += 1,
            None => tool_counts.push((tool.as_str(), 1)),
        }
    }
    let tools_summary = tool_counts
        .iter()
        .map(|(tool, count)| format!("{}×{}", tool, count))
        .collect::<Vec<_>>()
        .join(", ");
    format!(
        "<{}> [早期工具调用摘要: {}, 共 {} 次]",
        agent_name,
        tools_summary,
        calls.len()
    )
}
#[cfg(test)]
mod tests {
    //! Unit tests for broadcast parsing and other-agent tool-call compression.
    use super::*;

    #[test]
    fn test_extract_agent_source() {
        assert_eq!(
            extract_agent_source("<Frontend> hello"),
            Some(("Frontend".to_string(), " hello"))
        );
        assert_eq!(
            extract_agent_source(" <Backend> [调用工具 Read]"),
            Some(("Backend".to_string(), " [调用工具 Read]"))
        );
        assert_eq!(extract_agent_source("no prefix"), None);
        assert_eq!(extract_agent_source("<no-close"), None);
    }

    #[test]
    fn test_is_tool_call_broadcast() {
        assert_eq!(
            is_tool_call_broadcast("<Frontend> [调用工具 Read]"),
            Some(("Frontend".to_string(), "Read".to_string()))
        );
        assert_eq!(
            is_tool_call_broadcast("<Backend> [调用工具 Edit] "),
            Some(("Backend".to_string(), "Edit".to_string()))
        );
        assert_eq!(is_tool_call_broadcast("<Frontend> hello world"), None);
        assert_eq!(is_tool_call_broadcast("regular message"), None);
    }

    #[test]
    fn test_compress_no_other_agent() {
        let messages = vec![
            ChatMessage::text(MessageRole::User, "user question".to_string()),
            ChatMessage::text(MessageRole::Assistant, "response".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "Main", 5);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "user question");
        assert_eq!(result[1].content, "response");
    }

    #[test]
    fn test_compress_within_threshold() {
        let messages = vec![
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Edit]".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "Backend", 5);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "<Frontend> [调用工具 Read]");
        assert_eq!(result[1].content, "<Frontend> [调用工具 Edit]");
    }

    #[test]
    fn test_compress_exceed_threshold() {
        let messages = vec![
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Edit]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Bash]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Edit]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Bash]".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "Backend", 3);
        assert_eq!(result.len(), 4);
        assert!(result[0].content.contains("[早期工具调用摘要"));
        assert!(result[0].content.contains("Read×1"));
        assert!(result[0].content.contains("Edit×1"));
        assert!(result[0].content.contains("Bash×1"));
        assert!(result[0].content.contains("共 3 次"));
        assert_eq!(result[1].content, "<Frontend> [调用工具 Read]");
        assert_eq!(result[2].content, "<Frontend> [调用工具 Edit]");
        assert_eq!(result[3].content, "<Frontend> [调用工具 Bash]");
    }

    #[test]
    fn test_compress_multiple_agents() {
        let messages = vec![
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Edit]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Bash]".to_string()),
            ChatMessage::text(MessageRole::User, "<Backend> [调用工具 Bash]".to_string()),
            ChatMessage::text(MessageRole::User, "<Backend> [调用工具 Edit]".to_string()),
            ChatMessage::text(MessageRole::User, "<Backend> [调用工具 Read]".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "DevOps", 2);
        assert_eq!(result.len(), 6);
        assert!(result[0].content.contains("<Frontend>"));
        assert!(result[0].content.contains("[早期工具调用摘要"));
        let backend_summary_idx = result
            .iter()
            .position(|m| m.content.contains("<Backend> [早期工具调用摘要"))
            .expect("Backend summary should exist");
        assert!(result[backend_summary_idx].content.contains("Bash×1"));
    }

    #[test]
    fn test_compress_preserves_other_messages() {
        let messages = vec![
            ChatMessage::text(MessageRole::User, "user question".to_string()),
            ChatMessage::text(MessageRole::Assistant, "assistant response".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Edit]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Bash]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "another user message".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "Backend", 2);
        assert!(result.iter().any(|m| m.content == "user question"));
        assert!(result.iter().any(|m| m.content == "assistant response"));
        assert!(result.iter().any(|m| m.content == "another user message"));
    }

    #[test]
    fn test_compress_skips_self_agent() {
        // The calling agent's own broadcasts are never compressed, however many there are.
        let messages: Vec<ChatMessage> = (0..6)
            .map(|_| ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()))
            .collect();
        let result = compress_other_agent_toolcalls(&messages, "Frontend", 2);
        assert_eq!(result.len(), 6);
        assert!(result.iter().all(|m| m.content == "<Frontend> [调用工具 Read]"));
    }

    #[test]
    fn test_compress_zero_threshold_is_noop() {
        // A threshold of 0 disables compression rather than folding every broadcast.
        let messages: Vec<ChatMessage> = (0..4)
            .map(|_| ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()))
            .collect();
        let result = compress_other_agent_toolcalls(&messages, "Backend", 0);
        assert_eq!(result.len(), 4);
        assert!(result.iter().all(|m| !m.content.contains("[早期工具调用摘要")));
    }

    #[test]
    fn test_compress_ignores_non_user_roles() {
        // Only User-role messages are treated as broadcasts.
        let messages: Vec<ChatMessage> = (0..4)
            .map(|_| {
                ChatMessage::text(MessageRole::Assistant, "<Frontend> [调用工具 Read]".to_string())
            })
            .collect();
        let result = compress_other_agent_toolcalls(&messages, "Backend", 1);
        assert_eq!(result.len(), 4);
        assert!(result.iter().all(|m| m.content == "<Frontend> [调用工具 Read]"));
    }

    #[test]
    fn test_compress_summary_replaces_first_compressed_message() {
        // The summary sits where the first folded broadcast was and aggregates repeats.
        let messages = vec![
            ChatMessage::text(MessageRole::User, "user question".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Read]".to_string()),
            ChatMessage::text(MessageRole::User, "<Frontend> [调用工具 Bash]".to_string()),
        ];
        let result = compress_other_agent_toolcalls(&messages, "Backend", 1);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].content, "user question");
        assert_eq!(result[1].content, "<Frontend> [早期工具调用摘要: Read×2, 共 2 次]");
        assert_eq!(result[2].content, "<Frontend> [调用工具 Bash]");
    }
}