use super::super::constants::{
ROLE_ASSISTANT, ROLE_SYSTEM, ROLE_TOOL, ROLE_USER, WINDOW_KEEP_RECENT_MULTIPLIER,
WINDOW_QUOTA_ASST_TEXT, WINDOW_QUOTA_TOOL_GROUP, WINDOW_QUOTA_USER,
};
use super::super::storage::ChatMessage;
use super::compact::is_exempt_tool;
use crate::util::log::write_info_log;
/// A run of one or more consecutive chat messages that is kept or dropped as a whole.
///
/// Every variant stores indices into the message slice the units were parsed from.
#[derive(Clone, Debug)]
enum MessageUnit {
    /// A system message (the parser also uses this variant for unrecognized roles).
    System { message_index: usize },
    /// A user message.
    User { message_index: usize },
    /// An assistant message that carries no tool calls.
    AssistantText { message_index: usize },
    /// A head message (normally the tool-calling assistant message) together with the
    /// tool-result messages that directly follow it.
    ToolGroup {
        assistant_message_index: usize,
        tool_result_indices: Vec<usize>,
    },
}
impl MessageUnit {
    /// Tier number used by the quota stage of `select_units`
    /// (System = 0, User = 1, AssistantText = 2, ToolGroup = 3).
    fn priority(&self) -> u8 {
        match self {
            Self::System { .. } => 0,
            Self::User { .. } => 1,
            Self::AssistantText { .. } => 2,
            Self::ToolGroup { .. } => 3,
        }
    }

    /// Number of chat messages this unit spans.
    fn msg_count(&self) -> usize {
        match self {
            Self::ToolGroup {
                tool_result_indices,
                ..
            } => tool_result_indices.len() + 1,
            Self::System { .. } | Self::User { .. } | Self::AssistantText { .. } => 1,
        }
    }

    /// Index of the unit's first message in the original list.
    fn first_idx(&self) -> usize {
        let (Self::System { message_index: idx }
        | Self::User { message_index: idx }
        | Self::AssistantText { message_index: idx }
        | Self::ToolGroup {
            assistant_message_index: idx,
            ..
        }) = self;
        *idx
    }

    /// Rough token estimate for the unit: total characters divided by 3.
    ///
    /// For a ToolGroup this counts the head message's content, every tool result's
    /// content, and the name and arguments of each tool call on the head message.
    fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
        fn char_len(text: &str) -> usize {
            text.chars().count()
        }
        let total_chars = match self {
            Self::System { message_index }
            | Self::User { message_index }
            | Self::AssistantText { message_index } => char_len(&messages[*message_index].content),
            Self::ToolGroup {
                assistant_message_index,
                tool_result_indices,
            } => {
                let head = &messages[*assistant_message_index];
                let result_chars: usize = tool_result_indices
                    .iter()
                    .map(|&r| char_len(&messages[r].content))
                    .sum();
                let call_chars: usize = head
                    .tool_calls
                    .iter()
                    .flatten()
                    .map(|tc| char_len(&tc.name) + char_len(&tc.arguments))
                    .sum();
                char_len(&head.content) + result_chars + call_chars
            }
        };
        total_chars / 3
    }

    /// True when this is a ToolGroup whose head message calls at least one exempt tool.
    fn has_exempt_tool(&self, messages: &[ChatMessage], exempt_tools: &[String]) -> bool {
        let Self::ToolGroup {
            assistant_message_index,
            ..
        } = self
        else {
            return false;
        };
        messages[*assistant_message_index]
            .tool_calls
            .iter()
            .flatten()
            .any(|tc| is_exempt_tool(&tc.name, exempt_tools))
    }
}
/// Splits the flat message list into selection units, preserving order.
///
/// * A system message becomes a `System` unit. Any message whose role is not
///   system, user, assistant or tool also becomes a `System` unit.
/// * A user message becomes a `User` unit.
/// * An assistant message without `tool_calls` becomes an `AssistantText` unit.
/// * An assistant message with `tool_calls` becomes one `ToolGroup`, together with the
///   contiguous run of tool messages that immediately follows it.
/// * A run of tool messages with no tool-calling assistant message in front of it becomes a
///   `ToolGroup` whose head is the first tool message of the run.
///
/// Every message index ends up in exactly one unit, exactly once.
fn parse_message_units(messages: &[ChatMessage]) -> Vec<MessageUnit> {
    // Indices of the contiguous run of tool messages that starts at `start` (possibly empty).
    let tool_run_from = |start: usize| -> Vec<usize> {
        (start..messages.len())
            .take_while(|&j| messages[j].role == ROLE_TOOL)
            .collect()
    };
    let mut units = Vec::with_capacity(messages.len());
    let mut i = 0;
    while i < messages.len() {
        let msg = &messages[i];
        if msg.role == ROLE_SYSTEM {
            units.push(MessageUnit::System { message_index: i });
            i += 1;
        } else if msg.role == ROLE_USER {
            units.push(MessageUnit::User { message_index: i });
            i += 1;
        } else if msg.role == ROLE_ASSISTANT && msg.tool_calls.is_none() {
            units.push(MessageUnit::AssistantText { message_index: i });
            i += 1;
        } else if msg.role == ROLE_ASSISTANT || msg.role == ROLE_TOOL {
            // The group's head is either the tool-calling assistant message or, for an
            // orphaned run of tool results, the first tool message itself. The head is stored
            // only as `assistant_message_index`. Listing it in `tool_result_indices` as well
            // would count it twice against the budgets and emit it twice when the group is kept.
            let head = i;
            let tool_result_indices = tool_run_from(head + 1);
            i = head + 1 + tool_result_indices.len();
            units.push(MessageUnit::ToolGroup {
                assistant_message_index: head,
                tool_result_indices,
            });
        } else {
            // Unrecognized roles are treated like system messages.
            units.push(MessageUnit::System { message_index: i });
            i += 1;
        }
    }
    units
}
/// Outcome of `select_units`.
struct SelectionResult {
    /// One flag per parsed unit, in unit order; `true` means the unit is kept.
    retained: Vec<bool>,
}
/// Decides which of `units` to keep within the given message and token budgets.
///
/// Returns one flag per unit (`true` = keep). A unit costs `MessageUnit::msg_count` messages
/// and `MessageUnit::estimate_tokens` tokens. The stages run in order, and each one only sees
/// the budget left over by the stages before it:
///
/// 0. Every `System` unit is kept unconditionally. Its cost is counted but not checked.
/// 1. Recency floor: walking from newest to oldest, keep up to
///    `keep_recent * WINDOW_KEEP_RECENT_MULTIPLIER` non-system units, stopping at the first
///    one that does not fit.
/// 2. Every ToolGroup not yet kept that calls an exempt tool is kept if it fits, newest first.
/// 3. Per-tier quotas: the budget left after stage 2 is split between User, AssistantText and
///    ToolGroup units using the `WINDOW_QUOTA_*` ratios. Within each tier, units are tried
///    newest first and skipped if they would exceed the tier's share or the global budget.
/// 4. Greedy fill: any remaining unit that still fits is kept, newest first.
///
/// Finally, if no User unit was kept, the most recent User unit is forced in even if that
/// exceeds the budgets.
fn select_units(
    units: &[MessageUnit],
    messages: &[ChatMessage],
    max_history_messages: usize,
    max_context_tokens: usize,
    keep_recent: usize,
    exempt_tools: &[String],
) -> SelectionResult {
    let mut retained_flags = vec![false; units.len()];
    let mut used_message_count = 0usize;
    let mut used_token_count = 0usize;
    // Keeps a unit if it is not already kept and fits in both remaining budgets, charging its
    // cost to the counters. Returns `true` only when the unit was newly kept.
    // NOTE(review): despite its name, `message_index` is an index into `units`, not `messages`.
    let try_retain_unit = |message_index: usize,
                           retained: &mut [bool],
                           used_message_count: &mut usize,
                           used_token_count: &mut usize|
     -> bool {
        if retained[message_index] {
            return false;
        }
        let unit = &units[message_index];
        let unit_msg_count = unit.msg_count();
        let unit_tokens = unit.estimate_tokens(messages);
        if *used_message_count + unit_msg_count > max_history_messages
            || *used_token_count + unit_tokens > max_context_tokens
        {
            return false;
        }
        retained[message_index] = true;
        *used_message_count += unit_msg_count;
        *used_token_count += unit_tokens;
        true
    };
    // Stage 0: System units are always kept; their cost is charged without a budget check.
    for (i, unit) in units.iter().enumerate() {
        if matches!(unit, MessageUnit::System { .. }) {
            retained_flags[i] = true;
            used_message_count += unit.msg_count();
            used_token_count += unit.estimate_tokens(messages);
        }
    }
    // Stage 1: recency floor. Keep the newest non-system units, up to the scaled count.
    let recent_units_to_keep = keep_recent.saturating_mul(WINDOW_KEEP_RECENT_MULTIPLIER);
    let mut stage1_retained_count = 0usize;
    for i in (0..units.len()).rev() {
        if stage1_retained_count >= recent_units_to_keep {
            break;
        }
        if matches!(units[i], MessageUnit::System { .. }) {
            continue;
        }
        // Stop at the first unit that does not fit, so this stage keeps a contiguous run of
        // the newest non-system units.
        if try_retain_unit(
            i,
            &mut retained_flags,
            &mut used_message_count,
            &mut used_token_count,
        ) {
            stage1_retained_count += 1;
        } else {
            break;
        }
    }
    // Stage 2: ToolGroups that call an exempt tool, newest first, if they fit.
    for i in (0..units.len()).rev() {
        if retained_flags[i] {
            continue;
        }
        if units[i].has_exempt_tool(messages, exempt_tools) {
            try_retain_unit(
                i,
                &mut retained_flags,
                &mut used_message_count,
                &mut used_token_count,
            );
        }
    }
    // Stage 3: split what is left after stages 0-2 between tiers by fixed ratios.
    let remaining_msgs = max_history_messages.saturating_sub(used_message_count);
    let remaining_toks = max_context_tokens.saturating_sub(used_token_count);
    // (tier, share) pairs; the tier numbers match `MessageUnit::priority`.
    let quotas: [(u8, f32); 3] = [
        (1, WINDOW_QUOTA_USER),
        (2, WINDOW_QUOTA_ASST_TEXT),
        (3, WINDOW_QUOTA_TOOL_GROUP),
    ];
    for (tier_prio, ratio) in quotas {
        let tier_message_budget = ((remaining_msgs as f32) * ratio) as usize;
        let tier_token_budget = ((remaining_toks as f32) * ratio) as usize;
        // Snapshot of the counters, so the checks below measure what this tier alone consumed.
        let tier_start_msg_count = used_message_count;
        let tier_start_token_count = used_token_count;
        let mut tier_candidates: Vec<usize> = (0..units.len())
            .filter(|&i| !retained_flags[i] && units[i].priority() == tier_prio)
            .collect();
        // Newest first (descending position of the unit's first message).
        tier_candidates.sort_by(|&a, &b| units[b].first_idx().cmp(&units[a].first_idx()));
        for idx in tier_candidates {
            let unit = &units[idx];
            let unit_msg_count = unit.msg_count();
            let unit_tokens = unit.estimate_tokens(messages);
            if used_message_count - tier_start_msg_count + unit_msg_count > tier_message_budget {
                continue;
            }
            if used_token_count - tier_start_token_count + unit_tokens > tier_token_budget {
                continue;
            }
            try_retain_unit(
                idx,
                &mut retained_flags,
                &mut used_message_count,
                &mut used_token_count,
            );
        }
    }
    // Stage 4: greedy fill with anything that still fits, newest first.
    for i in (0..units.len()).rev() {
        try_retain_unit(
            i,
            &mut retained_flags,
            &mut used_message_count,
            &mut used_token_count,
        );
    }
    // Make sure at least one User unit survives: force in the most recent one if none was kept.
    // This bypasses the budget check and does not update the counters.
    let has_user_retained = units
        .iter()
        .enumerate()
        .any(|(i, u)| matches!(u, MessageUnit::User { .. }) && retained_flags[i]);
    if !has_user_retained
        && let Some(last_user_idx) = (0..units.len())
            .rev()
            .find(|&i| matches!(units[i], MessageUnit::User { .. }))
    {
        retained_flags[last_user_idx] = true;
    }
    SelectionResult {
        retained: retained_flags,
    }
}
/// Names of the tools called by a ToolGroup's head message, in call order.
///
/// Returns an empty list for any other unit kind, or when the head message has no tool calls.
fn tool_names_of(unit: &MessageUnit, messages: &[ChatMessage]) -> Vec<String> {
    let MessageUnit::ToolGroup {
        assistant_message_index,
        ..
    } = unit
    else {
        return Vec::new();
    };
    messages[*assistant_message_index]
        .tool_calls
        .iter()
        .flatten()
        .map(|call| call.name.clone())
        .collect()
}
/// Builds the assistant-role stand-in for a run of dropped tool-call groups.
///
/// The text lists the tool names joined by ", ". When `names` is empty, a generic notice
/// is used instead. The placeholder carries no tool calls, tool-call id or images.
fn merged_placeholder(names: &[String]) -> ChatMessage {
    let content = match names {
        [] => String::from("[Previous tool calls dropped]"),
        _ => format!("[Previous: used {}]", names.join(", ")),
    };
    ChatMessage {
        content,
        role: ROLE_ASSISTANT.to_string(),
        images: None,
        tool_call_id: None,
        tool_calls: None,
    }
}
/// Returns the part of `messages` to send as context, trimmed to fit the configured limits.
///
/// If the whole history already fits, it is returned unchanged (cloned). Otherwise the messages
/// are grouped into units, a subset is chosen by `select_units`, and the kept units are emitted
/// in their original order. All ToolGroups dropped between two kept units are replaced by a
/// single assistant placeholder that names the tools those groups called.
///
/// * `max_history_messages` – maximum number of messages; `0` means unlimited.
/// * `max_context_tokens_k` – token budget in thousands of estimated tokens (about 3 characters
///   per token); `0` means unlimited. Very large values saturate instead of overflowing.
/// * `keep_recent` – multiplied by `WINDOW_KEEP_RECENT_MULTIPLIER` to give the number of
///   newest units kept first.
/// * `exempt_tools` – tool names whose call groups are kept ahead of the quota stage.
pub fn select_messages(
    messages: &[ChatMessage],
    max_history_messages: usize,
    max_context_tokens_k: usize,
    keep_recent: usize,
    exempt_tools: &[String],
) -> Vec<ChatMessage> {
    let max_msgs = if max_history_messages == 0 {
        usize::MAX
    } else {
        max_history_messages
    };
    // Saturate so an oversized value clamps to "effectively unlimited". Plain `*` would wrap
    // to a tiny budget in release builds and panic in debug builds.
    let max_tokens = if max_context_tokens_k == 0 {
        usize::MAX
    } else {
        max_context_tokens_k.saturating_mul(1000)
    };
    // Fast path: everything fits, so nothing needs trimming.
    let total_tokens = estimate_tokens_simple(messages);
    if messages.len() <= max_msgs && total_tokens <= max_tokens {
        return messages.to_vec();
    }
    let units = parse_message_units(messages);
    let selection = select_units(
        &units,
        messages,
        max_msgs,
        max_tokens,
        keep_recent,
        exempt_tools,
    );
    let mut result = Vec::with_capacity(messages.len());
    // Tool names from ToolGroups dropped since the last kept unit. They are folded into one
    // placeholder, which is inserted just before the next kept unit (or at the end).
    let mut pending_dropped_names: Vec<String> = Vec::new();
    let flush_pending = |pending: &mut Vec<String>, out: &mut Vec<ChatMessage>| {
        if !pending.is_empty() {
            out.push(merged_placeholder(pending));
            pending.clear();
        }
    };
    for (i, unit) in units.iter().enumerate() {
        if selection.retained[i] {
            flush_pending(&mut pending_dropped_names, &mut result);
            match unit {
                MessageUnit::System { message_index }
                | MessageUnit::User { message_index }
                | MessageUnit::AssistantText { message_index } => {
                    result.push(messages[*message_index].clone());
                }
                MessageUnit::ToolGroup {
                    assistant_message_index,
                    tool_result_indices,
                } => {
                    result.push(messages[*assistant_message_index].clone());
                    for &result_index in tool_result_indices {
                        result.push(messages[result_index].clone());
                    }
                }
            }
        } else if matches!(unit, MessageUnit::ToolGroup { .. }) {
            pending_dropped_names.extend(tool_names_of(unit, messages));
        }
    }
    flush_pending(&mut pending_dropped_names, &mut result);
    let dropped_count = selection.retained.iter().filter(|&&r| !r).count();
    if dropped_count > 0 {
        write_info_log(
            "window_select",
            &format!(
                "三阶段窗口选择: 保留 {}/{} 单元, 丢弃 {} (tokens: {}→{}, keep_recent={})",
                units.len() - dropped_count,
                units.len(),
                dropped_count,
                total_tokens,
                estimate_tokens_simple(&result),
                keep_recent,
            ),
        );
    }
    result
}
/// Rough token estimate for a message list: total characters divided by 3.
///
/// Counts each message's content plus the name and arguments of every tool call it carries.
fn estimate_tokens_simple(messages: &[ChatMessage]) -> usize {
    let mut char_total = 0usize;
    for message in messages {
        char_total += message.content.chars().count();
        for call in message.tool_calls.iter().flatten() {
            char_total += call.name.chars().count() + call.arguments.chars().count();
        }
    }
    char_total / 3
}
#[cfg(test)]
mod tests {
    use super::super::super::storage::ToolCallItem;
    use super::*;

    /// Message with the given role and content and no tool metadata or images.
    fn plain_msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
            tool_calls: None,
            tool_call_id: None,
            images: None,
        }
    }

    fn user_msg(content: &str) -> ChatMessage {
        plain_msg(ROLE_USER, content)
    }

    fn assistant_msg(content: &str) -> ChatMessage {
        plain_msg(ROLE_ASSISTANT, content)
    }

    /// Assistant message with empty content and one tool call per name.
    /// Call ids are `call_0`, `call_1`, … and the arguments are always `{}`.
    fn tool_call_msg(names: &[&str]) -> ChatMessage {
        ChatMessage {
            tool_calls: Some(
                names
                    .iter()
                    .enumerate()
                    .map(|(i, name)| ToolCallItem {
                        id: format!("call_{}", i),
                        name: name.to_string(),
                        arguments: "{}".to_string(),
                    })
                    .collect(),
            ),
            ..plain_msg(ROLE_ASSISTANT, "")
        }
    }

    /// Tool-result message answering the call with id `call_id`.
    fn tool_result_msg(call_id: &str, content: &str) -> ChatMessage {
        ChatMessage {
            tool_call_id: Some(call_id.to_string()),
            ..plain_msg(ROLE_TOOL, content)
        }
    }

    #[test]
    fn test_no_truncation_needed() {
        let msgs = vec![user_msg("hello"), assistant_msg("hi")];
        let result = select_messages(&msgs, 100, 0, 10, &[]);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, ROLE_USER);
        assert_eq!(result[1].role, ROLE_ASSISTANT);
    }

    #[test]
    fn test_tool_group_dropped_first() {
        let huge_output = "huge output ".repeat(1000);
        let msgs = vec![
            user_msg("do something"),
            assistant_msg("let me check"),
            tool_call_msg(&["Shell"]),
            tool_result_msg("call_0", &huge_output),
            user_msg("what about this"),
            assistant_msg("here's the answer"),
        ];
        let result = select_messages(&msgs, 100, 1, 0, &[]);
        assert!(result.iter().any(|m| m.role == ROLE_USER));
        assert!(!result.iter().any(|m| m.role == ROLE_TOOL));
        assert!(result.iter().any(|m| m.content.contains("Previous: used")));
    }

    #[test]
    fn test_time_order_preserved() {
        let msgs = vec![
            user_msg("first"),
            assistant_msg("ok1"),
            tool_call_msg(&["Shell"]),
            tool_result_msg("call_0", "output"),
            user_msg("second"),
            assistant_msg("ok2"),
        ];
        let result = select_messages(&msgs, 100, 0, 10, &[]);
        let user_positions: Vec<usize> = result
            .iter()
            .enumerate()
            .filter_map(|(pos, m)| (m.role == ROLE_USER).then_some(pos))
            .collect();
        assert!(user_positions[0] < user_positions[1]);
    }

    #[test]
    fn test_placeholder_format() {
        let msgs = vec![
            user_msg("run"),
            tool_call_msg(&["Shell", "Read"]),
            tool_result_msg("call_0", &"x".repeat(2000)),
            tool_result_msg("call_1", &"y".repeat(2000)),
        ];
        let result = select_messages(&msgs, 100, 1, 0, &[]);
        let placeholder = result.iter().find(|m| m.content.contains("Previous: used"));
        assert!(placeholder.is_some());
        let p = placeholder.unwrap();
        assert!(p.content.contains("Shell"));
        assert!(p.content.contains("Read"));
        assert!(p.tool_calls.is_none());
    }

    #[test]
    fn test_exempt_tool_group_protected() {
        let mut msgs = vec![
            user_msg("load a skill"),
            tool_call_msg(&["LoadSkill"]),
            tool_result_msg("call_0", &"skill content ".repeat(500)),
        ];
        for (q, a) in [("q1", "a1"), ("q2", "a2"), ("q3", "a3")] {
            msgs.push(user_msg(q));
            msgs.push(assistant_msg(a));
        }
        let result = select_messages(&msgs, 100, 5, 0, &[]);
        assert!(
            result.iter().any(|m| m.role == ROLE_TOOL),
            "exempt tool result 应该被保留"
        );
    }

    #[test]
    fn test_stage1_time_fallback_keeps_recent_tool_group() {
        let mut msgs: Vec<ChatMessage> = (0..20)
            .map(|i| user_msg(&format!("old user {}", i).repeat(50)))
            .collect();
        msgs.extend([
            tool_call_msg(&["Shell"]),
            tool_result_msg("call_0", "recent shell output"),
            user_msg("latest"),
        ]);
        let result = select_messages(&msgs, 100, 2, 2, &[]);
        assert!(
            result.iter().any(|m| m.role == ROLE_TOOL),
            "最近的 tool result 应该被时间保底保留"
        );
        assert!(
            result
                .iter()
                .any(|m| m.role == ROLE_USER && m.content == "latest"),
            "最新 User 必须保留"
        );
    }
}