use super::compact::is_exempt_tool;
use super::policy::ContextTier;
use crate::command::chat::constants::{
WINDOW_KEEP_RECENT_MULTIPLIER, WINDOW_QUOTA_ASST_TEXT, WINDOW_QUOTA_TOOL_GROUP,
WINDOW_QUOTA_USER,
};
use crate::command::chat::storage::{ChatMessage, MessageRole};
use crate::util::log::write_info_log;
/// Heuristic characters-per-token ratio for cheap token estimates: a character
/// count is divided by this value (integer division, so it rounds down).
const SIMPLE_CHARS_PER_TOKEN: usize = 3;
/// Converts a budget expressed in thousands of tokens (the `*_k` parameters) into tokens.
const TOKEN_K_MULTIPLIER: usize = 1000;
/// A slice of conversation history that the window selector keeps or drops as a whole.
///
/// All indices point into the `messages` slice the units were parsed from.
#[derive(Debug, Clone)]
enum MessageUnit {
    /// A single system message (the parser also uses this for messages of any
    /// role it does not otherwise recognize).
    System { message_index: usize },
    /// A single user message.
    User { message_index: usize },
    /// A single assistant message that carries no tool calls.
    AssistantText { message_index: usize },
    /// An assistant message with tool calls, grouped with the tool-result
    /// messages that follow it so call and results are never separated.
    ToolGroup {
        /// Index of the assistant message issuing the tool calls. For a run of
        /// tool results with no preceding tool-calling assistant message, this is
        /// the index of the first tool message in the run.
        assistant_message_index: usize,
        /// Indices of the tool-result messages belonging to this group.
        tool_result_indices: Vec<usize>,
    },
}
impl MessageUnit {
    /// Retention priority of this unit, delegated to the matching `ContextTier`.
    fn priority(&self) -> u8 {
        let tier = match self {
            Self::System { .. } => ContextTier::System,
            Self::User { .. } => ContextTier::User,
            Self::AssistantText { .. } => ContextTier::Assistant,
            Self::ToolGroup { .. } => ContextTier::RegularTool,
        };
        tier.priority()
    }

    /// Number of messages this unit spans: one, plus its tool results for a `ToolGroup`.
    fn msg_count(&self) -> usize {
        if let Self::ToolGroup {
            tool_result_indices,
            ..
        } = self
        {
            tool_result_indices.len() + 1
        } else {
            1
        }
    }

    /// Index of the first message of this unit in the source `messages` slice.
    fn first_idx(&self) -> usize {
        match *self {
            Self::System { message_index }
            | Self::User { message_index }
            | Self::AssistantText { message_index } => message_index,
            Self::ToolGroup {
                assistant_message_index,
                ..
            } => assistant_message_index,
        }
    }

    /// Rough token estimate for this unit.
    ///
    /// Counts the characters of every message `content` in the unit. For a
    /// `ToolGroup` it also counts each tool call's `name` and `arguments` on the
    /// leading message. The total is then divided by `SIMPLE_CHARS_PER_TOKEN`.
    fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
        let content_chars = |idx: usize| messages[idx].content.chars().count();
        let total_chars = match self {
            Self::System { message_index }
            | Self::User { message_index }
            | Self::AssistantText { message_index } => content_chars(*message_index),
            Self::ToolGroup {
                assistant_message_index,
                tool_result_indices,
            } => {
                let head_chars = content_chars(*assistant_message_index);
                let result_chars: usize = tool_result_indices
                    .iter()
                    .map(|&result_idx| content_chars(result_idx))
                    .sum();
                let call_chars: usize = messages[*assistant_message_index]
                    .tool_calls
                    .iter()
                    .flatten()
                    .map(|call| call.name.chars().count() + call.arguments.chars().count())
                    .sum();
                head_chars + result_chars + call_chars
            }
        };
        total_chars / SIMPLE_CHARS_PER_TOKEN
    }

    /// True when this is a `ToolGroup` whose leading message calls at least one
    /// tool that `is_exempt_tool` matches against `exempt_tools`.
    fn has_exempt_tool(&self, messages: &[ChatMessage], exempt_tools: &[String]) -> bool {
        let Self::ToolGroup {
            assistant_message_index,
            ..
        } = self
        else {
            return false;
        };
        messages[*assistant_message_index]
            .tool_calls
            .iter()
            .flatten()
            .any(|call| is_exempt_tool(&call.name, exempt_tools))
    }
}
/// Splits a flat message history into [`MessageUnit`]s that are kept or dropped atomically.
///
/// * `System` and `User` messages each become a single-message unit.
/// * An `Assistant` message with `tool_calls` is grouped with the run of `Tool`
///   messages immediately following it into a `ToolGroup`.
/// * An `Assistant` message without `tool_calls` becomes `AssistantText`.
/// * A run of `Tool` messages with no preceding tool-calling assistant message
///   becomes a `ToolGroup`. Its `assistant_message_index` is the first tool
///   message and its `tool_result_indices` are the remaining ones.
/// * Messages of any other role are treated as `System` units.
///
/// Units are returned in message order, and every message index belongs to
/// exactly one unit.
fn parse_message_units(messages: &[ChatMessage]) -> Vec<MessageUnit> {
    // Indices of the consecutive `Tool` messages starting at `from` (possibly empty).
    let tool_run_from = |from: usize| -> Vec<usize> {
        (from..messages.len())
            .take_while(|&j| messages[j].role == MessageRole::Tool)
            .collect()
    };
    let mut units = Vec::with_capacity(messages.len());
    let mut i = 0;
    while i < messages.len() {
        let msg = &messages[i];
        if msg.role == MessageRole::System {
            units.push(MessageUnit::System { message_index: i });
            i += 1;
        } else if msg.role == MessageRole::User {
            units.push(MessageUnit::User { message_index: i });
            i += 1;
        } else if msg.role == MessageRole::Assistant {
            if msg.tool_calls.is_some() {
                let assistant_message_index = i;
                let tool_result_indices = tool_run_from(i + 1);
                i += 1 + tool_result_indices.len();
                units.push(MessageUnit::ToolGroup {
                    assistant_message_index,
                    tool_result_indices,
                });
            } else {
                units.push(MessageUnit::AssistantText { message_index: i });
                i += 1;
            }
        } else if msg.role == MessageRole::Tool {
            // Orphaned tool results. The first message fills the
            // `assistant_message_index` slot, so it must not also appear in
            // `tool_result_indices`. Listing it twice would double-count it in
            // `msg_count`/`estimate_tokens` and emit it twice when the group is
            // retained.
            let start = i;
            let tool_result_indices = tool_run_from(i + 1);
            i += 1 + tool_result_indices.len();
            units.push(MessageUnit::ToolGroup {
                assistant_message_index: start,
                tool_result_indices,
            });
        } else {
            units.push(MessageUnit::System { message_index: i });
            i += 1;
        }
    }
    units
}
/// Outcome of `select_units`.
struct SelectionResult {
    /// One flag per entry of the `units` slice given to `select_units`.
    /// `retained[i]` is `true` when unit `i` should be kept.
    retained: Vec<bool>,
}
/// Decides which units to keep so the retained history fits within
/// `max_history_messages` messages and `max_context_tokens` estimated tokens.
///
/// Unit sizes come from `MessageUnit::msg_count` and `MessageUnit::estimate_tokens`.
/// "Newest-first" below means iterating from the last unit backwards. This
/// presumes `units` is in conversation order, as `parse_message_units` produces.
///
/// Passes run in order, and each later pass only sees the budget the earlier ones left:
///
/// 0. Every `System` unit is retained unconditionally. Its size is still charged
///    against both budgets, even if that overshoots them.
/// 1. Recent window: newest-first, up to `keep_recent * WINDOW_KEEP_RECENT_MULTIPLIER`
///    (saturating) non-system units are retained. The pass stops at the first
///    unit that does not fit.
/// 2. Exempt tools: tool groups for which `has_exempt_tool` is true are retained
///    newest-first wherever they still fit. A unit that does not fit is skipped,
///    and the pass continues.
/// 3. Tier quotas: the budget remaining after pass 2 is split per tier. The tiers
///    are User, Assistant and RegularTool, in that order, with ratios
///    `WINDOW_QUOTA_USER`, `WINDOW_QUOTA_ASST_TEXT` and `WINDOW_QUOTA_TOOL_GROUP`.
///    Each tier's candidates are tried newest-first.
/// 4. Back-fill: any budget still left is filled newest-first regardless of tier.
///
/// Finally, if no `User` unit was retained, the most recent one is force-retained
/// without checking or charging the budgets.
fn select_units(
    units: &[MessageUnit],
    messages: &[ChatMessage],
    max_history_messages: usize,
    max_context_tokens: usize,
    keep_recent: usize,
    exempt_tools: &[String],
) -> SelectionResult {
    let mut retained_flags = vec![false; units.len()];
    let mut used_message_count = 0usize;
    let mut used_token_count = 0usize;
    // Tries to retain one unit. NOTE: despite its name, `message_index` is an
    // index into `units`, not into `messages`. The attempt fails (returns false)
    // if the unit is already retained or would push either running total past
    // its global budget. On success, the unit is flagged and both totals are
    // charged.
    let try_retain_unit = |message_index: usize,
                           retained: &mut [bool],
                           used_message_count: &mut usize,
                           used_token_count: &mut usize|
     -> bool {
        if retained[message_index] {
            return false;
        }
        let unit = &units[message_index];
        let unit_msg_count = unit.msg_count();
        let unit_tokens = unit.estimate_tokens(messages);
        if *used_message_count + unit_msg_count > max_history_messages
            || *used_token_count + unit_tokens > max_context_tokens
        {
            return false;
        }
        retained[message_index] = true;
        *used_message_count += unit_msg_count;
        *used_token_count += unit_tokens;
        true
    };
    // Pass 0: system units are always kept; no budget check is applied.
    for (i, unit) in units.iter().enumerate() {
        if matches!(unit, MessageUnit::System { .. }) {
            retained_flags[i] = true;
            used_message_count += unit.msg_count();
            used_token_count += unit.estimate_tokens(messages);
        }
    }
    // Pass 1: keep a contiguous recent window, newest-first. Stop at the first
    // unit that does not fit, so the window never has gaps.
    let recent_units_to_keep = keep_recent.saturating_mul(WINDOW_KEEP_RECENT_MULTIPLIER);
    let mut stage1_retained_count = 0usize;
    for i in (0..units.len()).rev() {
        if stage1_retained_count >= recent_units_to_keep {
            break;
        }
        if matches!(units[i], MessageUnit::System { .. }) {
            continue;
        }
        if try_retain_unit(
            i,
            &mut retained_flags,
            &mut used_message_count,
            &mut used_token_count,
        ) {
            stage1_retained_count += 1;
        } else {
            break;
        }
    }
    // Pass 2: tool groups that use an exempt tool, newest-first. Failures are
    // ignored so smaller exempt groups further back can still fit.
    for i in (0..units.len()).rev() {
        if retained_flags[i] {
            continue;
        }
        if units[i].has_exempt_tool(messages, exempt_tools) {
            try_retain_unit(
                i,
                &mut retained_flags,
                &mut used_message_count,
                &mut used_token_count,
            );
        }
    }
    // Pass 3: per-tier quotas. The remaining budget is snapshotted once here, so
    // every tier's quota is a fraction of the same post-pass-2 remainder.
    let remaining_msgs = max_history_messages.saturating_sub(used_message_count);
    let remaining_toks = max_context_tokens.saturating_sub(used_token_count);
    let quotas: [(u8, f32); 3] = [
        (ContextTier::User.priority(), WINDOW_QUOTA_USER),
        (ContextTier::Assistant.priority(), WINDOW_QUOTA_ASST_TEXT),
        (ContextTier::RegularTool.priority(), WINDOW_QUOTA_TOOL_GROUP),
    ];
    for (tier_prio, ratio) in quotas {
        let tier_message_budget = ((remaining_msgs as f32) * ratio) as usize;
        let tier_token_budget = ((remaining_toks as f32) * ratio) as usize;
        // Track what this tier has spent so far, measured from these starting totals.
        let tier_start_msg_count = used_message_count;
        let tier_start_token_count = used_token_count;
        let mut tier_candidates: Vec<usize> = (0..units.len())
            .filter(|&i| !retained_flags[i] && units[i].priority() == tier_prio)
            .collect();
        // Descending by first message index, i.e. newest-first.
        tier_candidates.sort_by(|&a, &b| units[b].first_idx().cmp(&units[a].first_idx()));
        for idx in tier_candidates {
            let unit = &units[idx];
            let unit_msg_count = unit.msg_count();
            let unit_tokens = unit.estimate_tokens(messages);
            // Skip units that would exceed this tier's quota. The global budget
            // is checked separately inside `try_retain_unit`.
            if used_message_count - tier_start_msg_count + unit_msg_count > tier_message_budget {
                continue;
            }
            if used_token_count - tier_start_token_count + unit_tokens > tier_token_budget {
                continue;
            }
            try_retain_unit(
                idx,
                &mut retained_flags,
                &mut used_message_count,
                &mut used_token_count,
            );
        }
    }
    // Pass 4: spend whatever global budget is left, newest-first, any tier.
    for i in (0..units.len()).rev() {
        try_retain_unit(
            i,
            &mut retained_flags,
            &mut used_message_count,
            &mut used_token_count,
        );
    }
    // Guarantee at least one user message survives. The most recent one is
    // forced in, even over budget, and the running totals are not updated.
    let has_user_retained = units
        .iter()
        .enumerate()
        .any(|(i, u)| matches!(u, MessageUnit::User { .. }) && retained_flags[i]);
    if !has_user_retained
        && let Some(last_user_idx) = (0..units.len())
            .rev()
            .find(|&i| matches!(units[i], MessageUnit::User { .. }))
    {
        retained_flags[last_user_idx] = true;
    }
    SelectionResult {
        retained: retained_flags,
    }
}
/// Names of the tools called by a `ToolGroup`'s leading message, in call order.
///
/// Returns an empty list for non-tool units and for groups whose leading
/// message has no `tool_calls`.
fn tool_names_of(unit: &MessageUnit, messages: &[ChatMessage]) -> Vec<String> {
    let MessageUnit::ToolGroup {
        assistant_message_index,
        ..
    } = unit
    else {
        return Vec::new();
    };
    messages[*assistant_message_index]
        .tool_calls
        .iter()
        .flatten()
        .map(|call| call.name.clone())
        .collect()
}
/// Builds the assistant-role text message that stands in for dropped tool groups.
///
/// With no names it reads `[Previous tool calls dropped]`. Otherwise it lists
/// the tool names, comma-separated, as `[Previous: used a, b]`.
fn merged_placeholder(names: &[String]) -> ChatMessage {
    let content = match names {
        [] => String::from("[Previous tool calls dropped]"),
        _ => format!("[Previous: used {}]", names.join(", ")),
    };
    ChatMessage::text(MessageRole::Assistant, content)
}
/// Trims `messages` to fit a message-count and token budget. An assistant
/// message with tool calls is always kept or dropped together with its tool results.
///
/// * `max_history_messages` — maximum number of messages to keep; `0` means unlimited.
/// * `max_context_tokens_k` — budget in thousands of estimated tokens; `0` means
///   unlimited. Values too large to convert saturate to `usize::MAX`.
/// * `keep_recent` — forwarded to `select_units`, which scales it to size the
///   always-kept recent window.
/// * `exempt_tools` — tool names whose groups are retained with priority.
///
/// If the whole history already fits, it is returned unchanged. Otherwise the
/// history is split into units, a subset is selected, and retained units are
/// emitted in their original order. Between two retained units, all dropped
/// tool groups collapse into one assistant placeholder listing the tool names
/// they called. Other dropped units leave no trace. A summary is logged when
/// anything was dropped.
pub fn select_messages(
    messages: &[ChatMessage],
    max_history_messages: usize,
    max_context_tokens_k: usize,
    keep_recent: usize,
    exempt_tools: &[String],
) -> Vec<ChatMessage> {
    let max_msgs = if max_history_messages == 0 {
        usize::MAX
    } else {
        max_history_messages
    };
    // Saturate rather than overflow: a huge configured value should behave as
    // "effectively unlimited" instead of wrapping (release) or panicking (debug).
    let max_tokens = if max_context_tokens_k == 0 {
        usize::MAX
    } else {
        max_context_tokens_k.saturating_mul(TOKEN_K_MULTIPLIER)
    };
    let total_tokens = estimate_tokens_simple(messages);
    if messages.len() <= max_msgs && total_tokens <= max_tokens {
        return messages.to_vec();
    }
    let units = parse_message_units(messages);
    let selection = select_units(
        &units,
        messages,
        max_msgs,
        max_tokens,
        keep_recent,
        exempt_tools,
    );
    let mut result = Vec::with_capacity(messages.len());
    // Tool names of dropped tool groups seen since the last retained unit.
    let mut pending_dropped_names: Vec<String> = Vec::new();
    // Emits one placeholder for the accumulated dropped tool groups, if any.
    let flush_pending = |pending: &mut Vec<String>, out: &mut Vec<ChatMessage>| {
        if !pending.is_empty() {
            out.push(merged_placeholder(pending));
            pending.clear();
        }
    };
    for (i, unit) in units.iter().enumerate() {
        if selection.retained[i] {
            flush_pending(&mut pending_dropped_names, &mut result);
            match unit {
                MessageUnit::System { message_index }
                | MessageUnit::User { message_index }
                | MessageUnit::AssistantText { message_index } => {
                    result.push(messages[*message_index].clone());
                }
                MessageUnit::ToolGroup {
                    assistant_message_index,
                    tool_result_indices,
                } => {
                    result.push(messages[*assistant_message_index].clone());
                    for &result_index in tool_result_indices {
                        result.push(messages[result_index].clone());
                    }
                }
            }
        } else if matches!(unit, MessageUnit::ToolGroup { .. }) {
            pending_dropped_names.extend(tool_names_of(unit, messages));
        }
    }
    flush_pending(&mut pending_dropped_names, &mut result);
    let dropped_count = selection.retained.iter().filter(|&&r| !r).count();
    if dropped_count > 0 {
        write_info_log(
            "window_select",
            &format!(
                "三阶段窗口选择: 保留 {}/{} 单元, 丢弃 {} (tokens: {}→{}, keep_recent={})",
                units.len() - dropped_count,
                units.len(),
                dropped_count,
                total_tokens,
                estimate_tokens_simple(&result),
                keep_recent,
            ),
        );
    }
    result
}
fn estimate_tokens_simple(messages: &[ChatMessage]) -> usize {
let total_chars: usize = messages
.iter()
.map(|m| {
let mut chars = m.content.chars().count();
if let Some(ref tcs) = m.tool_calls {
for tc in tcs {
chars += tc.name.chars().count() + tc.arguments.chars().count();
}
}
chars
})
.sum();
total_chars / 3
}
// Unit tests for this module live in a separate `tests` module file and are
// compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;