1use super::compact::is_exempt_tool;
12use super::policy::ContextTier;
13use crate::constants::{
14 WINDOW_KEEP_RECENT_MULTIPLIER, WINDOW_QUOTA_ASST_TEXT, WINDOW_QUOTA_TOOL_GROUP,
15 WINDOW_QUOTA_USER,
16};
17use crate::storage::{ChatMessage, MessageRole};
18use crate::util::log::write_info_log;
19
/// Divisor used by the rough token estimate: this many characters are
/// counted as one token.
const SIMPLE_CHARS_PER_TOKEN: usize = 3;

/// Factor that converts a limit expressed in thousands of tokens ("k") into
/// a plain token count.
const TOKEN_K_MULTIPLIER: usize = 1_000;
25
/// An indivisible selection unit: one or more messages that are kept or
/// dropped together. All indices point into the original message slice.
#[derive(Clone, Debug)]
enum MessageUnit {
    /// A single system message.
    System { message_index: usize },
    /// A single user message.
    User { message_index: usize },
    /// A single assistant message that carries no tool calls.
    AssistantText { message_index: usize },
    /// A leading message together with the tool-result messages that
    /// directly follow it.
    ToolGroup {
        /// Index of the message that opens the group.
        assistant_message_index: usize,
        /// Indices of the tool-result messages belonging to the group.
        tool_result_indices: Vec<usize>,
    },
}
45
46impl MessageUnit {
47 fn priority(&self) -> u8 {
56 match self {
57 MessageUnit::System { .. } => ContextTier::System.priority(),
58 MessageUnit::User { .. } => ContextTier::User.priority(),
59 MessageUnit::AssistantText { .. } => ContextTier::Assistant.priority(),
60 MessageUnit::ToolGroup { .. } => ContextTier::RegularTool.priority(),
61 }
62 }
63
64 fn msg_count(&self) -> usize {
66 match self {
67 MessageUnit::System { .. }
68 | MessageUnit::User { .. }
69 | MessageUnit::AssistantText { .. } => 1,
70 MessageUnit::ToolGroup {
71 tool_result_indices,
72 ..
73 } => 1 + tool_result_indices.len(),
74 }
75 }
76
77 fn first_idx(&self) -> usize {
79 match self {
80 MessageUnit::System { message_index }
81 | MessageUnit::User { message_index }
82 | MessageUnit::AssistantText { message_index } => *message_index,
83 MessageUnit::ToolGroup {
84 assistant_message_index,
85 ..
86 } => *assistant_message_index,
87 }
88 }
89
90 fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
92 let total_chars: usize = match self {
93 MessageUnit::System { message_index }
94 | MessageUnit::User { message_index }
95 | MessageUnit::AssistantText { message_index } => {
96 messages[*message_index].content.chars().count()
97 }
98 MessageUnit::ToolGroup {
99 assistant_message_index,
100 tool_result_indices,
101 } => {
102 let mut chars = messages[*assistant_message_index].content.chars().count();
103 for &result_index in tool_result_indices {
104 chars += messages[result_index].content.chars().count();
105 }
106 if let Some(ref tcs) = messages[*assistant_message_index].tool_calls {
107 for tc in tcs {
108 chars += tc.name.chars().count() + tc.arguments.chars().count();
109 }
110 }
111 chars
112 }
113 };
114 total_chars / SIMPLE_CHARS_PER_TOKEN
115 }
116
117 fn has_exempt_tool(&self, messages: &[ChatMessage], exempt_tools: &[String]) -> bool {
119 match self {
120 MessageUnit::ToolGroup {
121 assistant_message_index,
122 ..
123 } => messages[*assistant_message_index]
124 .tool_calls
125 .as_ref()
126 .map(|tcs| tcs.iter().any(|tc| is_exempt_tool(&tc.name, exempt_tools)))
127 .unwrap_or(false),
128 _ => false,
129 }
130 }
131}
132
133fn parse_message_units(messages: &[ChatMessage]) -> Vec<MessageUnit> {
137 let mut units = Vec::with_capacity(messages.len());
138 let mut i = 0;
139
140 while i < messages.len() {
141 let msg = &messages[i];
142
143 if msg.role == MessageRole::System {
144 units.push(MessageUnit::System { message_index: i });
145 i += 1;
146 } else if msg.role == MessageRole::User {
147 units.push(MessageUnit::User { message_index: i });
148 i += 1;
149 } else if msg.role == MessageRole::Assistant {
150 if msg.tool_calls.is_some() {
151 let assistant_message_index = i;
153 let mut tool_result_indices = Vec::new(); i += 1;
155 while i < messages.len() && messages[i].role == MessageRole::Tool {
156 tool_result_indices.push(i);
157 i += 1;
158 }
159 units.push(MessageUnit::ToolGroup {
160 assistant_message_index,
161 tool_result_indices,
162 });
163 } else {
164 units.push(MessageUnit::AssistantText { message_index: i });
166 i += 1;
167 }
168 } else if msg.role == MessageRole::Tool {
169 let start = i;
172 let mut tool_result_indices = vec![i];
173 i += 1;
174 while i < messages.len() && messages[i].role == MessageRole::Tool {
175 tool_result_indices.push(i);
176 i += 1;
177 }
178 units.push(MessageUnit::ToolGroup {
180 assistant_message_index: start, tool_result_indices,
182 });
183 } else {
184 units.push(MessageUnit::System { message_index: i });
186 i += 1;
187 }
188 }
189
190 units
191}
192
/// Outcome of unit selection.
struct SelectionResult {
    /// One flag per unit, in unit order; `true` means the unit is kept.
    retained: Vec<bool>,
}
200
/// Limits and options that drive unit selection.
struct SelectUnitsParams<'tools> {
    /// Upper bound on the number of messages kept.
    max_history_messages: usize,
    /// Upper bound on the estimated token count of the kept messages.
    max_context_tokens: usize,
    /// Base count of most-recent units to keep; selection scales it by a
    /// multiplier constant before use.
    keep_recent: usize,
    /// Tool names matched by the exemption check during selection.
    exempt_tools: &'tools [String],
}
210
211fn select_units(
216 units: &[MessageUnit],
217 messages: &[ChatMessage],
218 params: &SelectUnitsParams,
219) -> SelectionResult {
220 let max_history_messages = params.max_history_messages;
222 let max_context_tokens = params.max_context_tokens;
223 let keep_recent = params.keep_recent;
224 let exempt_tools = params.exempt_tools;
225 let mut retained_flags = vec![false; units.len()];
226 let mut used_message_count = 0usize;
227 let mut used_token_count = 0usize;
228
229 let try_retain_unit = |message_index: usize,
231 retained: &mut [bool],
232 used_message_count: &mut usize,
233 used_token_count: &mut usize|
234 -> bool {
235 if retained[message_index] {
236 return false;
237 }
238 let unit = &units[message_index];
239 let unit_msg_count = unit.msg_count();
240 let unit_tokens = unit.estimate_tokens(messages);
241 if *used_message_count + unit_msg_count > max_history_messages
242 || *used_token_count + unit_tokens > max_context_tokens
243 {
244 return false;
245 }
246 retained[message_index] = true;
247 *used_message_count += unit_msg_count;
248 *used_token_count += unit_tokens;
249 true
250 };
251
252 for (i, unit) in units.iter().enumerate() {
254 if matches!(unit, MessageUnit::System { .. }) {
255 retained_flags[i] = true;
257 used_message_count += unit.msg_count();
258 used_token_count += unit.estimate_tokens(messages);
259 }
260 }
261
262 let recent_units_to_keep = keep_recent.saturating_mul(WINDOW_KEEP_RECENT_MULTIPLIER);
264 let mut stage1_retained_count = 0usize;
265 for i in (0..units.len()).rev() {
266 if stage1_retained_count >= recent_units_to_keep {
267 break;
268 }
269 if matches!(units[i], MessageUnit::System { .. }) {
270 continue;
271 }
272 if try_retain_unit(
273 i,
274 &mut retained_flags,
275 &mut used_message_count,
276 &mut used_token_count,
277 ) {
278 stage1_retained_count += 1;
279 } else {
280 break;
282 }
283 }
284
285 for i in (0..units.len()).rev() {
287 if retained_flags[i] {
288 continue;
289 }
290 if units[i].has_exempt_tool(messages, exempt_tools) {
291 try_retain_unit(
292 i,
293 &mut retained_flags,
294 &mut used_message_count,
295 &mut used_token_count,
296 );
297 }
298 }
299
300 let remaining_msgs = max_history_messages.saturating_sub(used_message_count);
302 let remaining_toks = max_context_tokens.saturating_sub(used_token_count);
303
304 let quotas: [(u8, f32); 3] = [
307 (ContextTier::User.priority(), WINDOW_QUOTA_USER),
308 (ContextTier::Assistant.priority(), WINDOW_QUOTA_ASST_TEXT),
309 (ContextTier::RegularTool.priority(), WINDOW_QUOTA_TOOL_GROUP),
310 ];
311
312 for (tier_prio, ratio) in quotas {
313 let tier_message_budget = ((remaining_msgs as f32) * ratio) as usize;
315 let tier_token_budget = ((remaining_toks as f32) * ratio) as usize;
316 let tier_start_msg_count = used_message_count;
317 let tier_start_token_count = used_token_count;
318
319 let mut tier_candidates: Vec<usize> = (0..units.len())
321 .filter(|&i| !retained_flags[i] && units[i].priority() == tier_prio)
322 .collect();
323 tier_candidates.sort_by(|&a, &b| units[b].first_idx().cmp(&units[a].first_idx()));
324
325 for idx in tier_candidates {
326 let unit = &units[idx];
327 let unit_msg_count = unit.msg_count();
328 let unit_tokens = unit.estimate_tokens(messages);
329 if used_message_count - tier_start_msg_count + unit_msg_count > tier_message_budget {
331 continue;
332 }
333 if used_token_count - tier_start_token_count + unit_tokens > tier_token_budget {
334 continue;
335 }
336 try_retain_unit(
337 idx,
338 &mut retained_flags,
339 &mut used_message_count,
340 &mut used_token_count,
341 );
342 }
343 }
344
345 for i in (0..units.len()).rev() {
347 try_retain_unit(
348 i,
349 &mut retained_flags,
350 &mut used_message_count,
351 &mut used_token_count,
352 );
353 }
354
355 let has_user_retained = units
357 .iter()
358 .enumerate()
359 .any(|(i, u)| matches!(u, MessageUnit::User { .. }) && retained_flags[i]);
360 if !has_user_retained
361 && let Some(last_user_idx) = (0..units.len())
362 .rev()
363 .find(|&i| matches!(units[i], MessageUnit::User { .. }))
364 {
365 retained_flags[last_user_idx] = true;
366 }
367
368 SelectionResult {
369 retained: retained_flags,
370 }
371}
372
373fn tool_names_of(unit: &MessageUnit, messages: &[ChatMessage]) -> Vec<String> {
377 match unit {
378 MessageUnit::ToolGroup {
379 assistant_message_index,
380 ..
381 } => messages[*assistant_message_index]
382 .tool_calls
383 .as_ref()
384 .map(|tcs| tcs.iter().map(|tc| tc.name.clone()).collect())
385 .unwrap_or_default(),
386 _ => Vec::new(),
387 }
388}
389
390fn merged_placeholder(names: &[String]) -> ChatMessage {
393 let content = if names.is_empty() {
394 "[Previous tool calls dropped]".to_string()
395 } else {
396 format!("[Previous: used {}]", names.join(", "))
397 };
398 ChatMessage::text(MessageRole::Assistant, content)
399}
400
401pub fn select_messages(
414 messages: &[ChatMessage],
415 max_history_messages: usize,
416 max_context_tokens_k: usize,
417 keep_recent: usize,
418 exempt_tools: &[String],
419) -> Vec<ChatMessage> {
420 let max_msgs = if max_history_messages == 0 {
421 usize::MAX
422 } else {
423 max_history_messages
424 };
425 let max_tokens = if max_context_tokens_k == 0 {
426 usize::MAX
427 } else {
428 max_context_tokens_k * TOKEN_K_MULTIPLIER
429 };
430
431 let total_tokens = estimate_tokens_simple(messages);
432 if messages.len() <= max_msgs && total_tokens <= max_tokens {
433 return messages.to_vec();
434 }
435
436 let units = parse_message_units(messages);
437 let selection = select_units(
438 &units,
439 messages,
440 &SelectUnitsParams {
441 max_history_messages: max_msgs,
442 max_context_tokens: max_tokens,
443 keep_recent,
444 exempt_tools,
445 },
446 );
447
448 let mut result = Vec::with_capacity(messages.len());
450 let mut pending_dropped_names: Vec<String> = Vec::new(); let flush_pending = |pending: &mut Vec<String>, out: &mut Vec<ChatMessage>| {
453 if !pending.is_empty() {
454 out.push(merged_placeholder(pending));
455 pending.clear();
456 }
457 };
458
459 for (i, unit) in units.iter().enumerate() {
460 if selection.retained[i] {
461 flush_pending(&mut pending_dropped_names, &mut result);
462 match unit {
463 MessageUnit::System { message_index }
464 | MessageUnit::User { message_index }
465 | MessageUnit::AssistantText { message_index } => {
466 result.push(messages[*message_index].clone());
467 }
468 MessageUnit::ToolGroup {
469 assistant_message_index,
470 tool_result_indices,
471 } => {
472 result.push(messages[*assistant_message_index].clone());
473 for &result_index in tool_result_indices {
474 result.push(messages[result_index].clone());
475 }
476 }
477 }
478 } else if matches!(unit, MessageUnit::ToolGroup { .. }) {
479 pending_dropped_names.extend(tool_names_of(unit, messages));
481 }
482 }
484 flush_pending(&mut pending_dropped_names, &mut result);
485
486 let dropped_count = selection.retained.iter().filter(|&&r| !r).count();
487 if dropped_count > 0 {
488 write_info_log(
489 "window_select",
490 &format!(
491 "三阶段窗口选择: 保留 {}/{} 单元, 丢弃 {} (tokens: {}→{}, keep_recent={})",
492 units.len() - dropped_count,
493 units.len(),
494 dropped_count,
495 total_tokens,
496 estimate_tokens_simple(&result),
497 keep_recent,
498 ),
499 );
500 }
501
502 result
503}
504
505fn estimate_tokens_simple(messages: &[ChatMessage]) -> usize {
507 let total_chars: usize = messages
508 .iter()
509 .map(|m| {
510 let mut chars = m.content.chars().count();
511 if let Some(ref tcs) = m.tool_calls {
512 for tc in tcs {
513 chars += tc.name.chars().count() + tc.arguments.chars().count();
514 }
515 }
516 chars
517 })
518 .sum();
519 total_chars / 3
520}
521
522#[cfg(test)]
523mod tests;