Skip to main content

j_agent/context/
window.rs

1//! 优先级消息窗口选择(三阶段 + 比例配额 + 溢出)
2//!
3//! 核心原则:
//! 1. **时间保底**:最近 K 个非 System unit 优先保留,K = keep_recent × WINDOW_KEEP_RECENT_MULTIPLIER(keep_recent 与 micro_compact.keep_recent 对齐);预算耗尽即停止
5//! 2. **豁免保底**:属于 EXEMPT_TOOLS 的 ToolGroup 优先保留(技能/任务上下文)
6//! 3. **比例配额**:剩余预算按比例分配给 User / AssistantText / ToolGroup,
7//!    而不是层层堆叠;某 tier 配额用不完时按时间倒序溢出到未保留 unit
8//!
9//! 输出顺序始终保持原始时间顺序;丢弃的 ToolGroup 用统一占位符替换。
10
11use super::compact::is_exempt_tool;
12use super::policy::ContextTier;
13use crate::constants::{
14    WINDOW_KEEP_RECENT_MULTIPLIER, WINDOW_QUOTA_ASST_TEXT, WINDOW_QUOTA_TOOL_GROUP,
15    WINDOW_QUOTA_USER,
16};
17use crate::storage::{ChatMessage, MessageRole};
18use crate::util::log::write_info_log;
19
/// Rough token heuristic: roughly one token per this many characters
/// (characters are counted with `chars()`, not bytes).
const SIMPLE_CHARS_PER_TOKEN: usize = 3;

/// Converts a budget expressed in thousands of tokens ("K") into a raw token count.
const TOKEN_K_MULTIPLIER: usize = 1_000;
25
26// ========== MessageUnit 定义 ==========
27
/// A group of messages that the window selector keeps or drops atomically:
/// either every message in the unit is retained, or none is.
#[derive(Debug, Clone)]
enum MessageUnit {
    /// A system message. `parse_message_units` also maps unrecognized roles to this
    /// variant; `select_units` always retains these units.
    System { message_index: usize },
    /// A user message.
    User { message_index: usize },
    /// A plain-text assistant message (has content, no `tool_calls`).
    AssistantText { message_index: usize },
    /// A tool-call group: the assistant(tool_calls) message plus the tool result
    /// messages that immediately follow it. Retained or dropped as one unit.
    ToolGroup {
        /// Index of the assistant(tool_calls) message. For orphan tool results
        /// (no preceding assistant), the parser stores the first tool result's index here.
        assistant_message_index: usize,
        /// Indices of the tool result messages belonging to this group.
        tool_result_indices: Vec<usize>,
    },
}
45
46impl MessageUnit {
47    /// 消息单元的优先级(数值越小优先级越高)
48    ///
49    /// 数值与 `context::policy::ContextTier` 对齐:
50    /// - System=0, User=1, KeyTool=2, Assistant=3, RegularTool=4
51    ///
52    /// ToolGroup 的 tier 取组内 tool_call 的最高优先级(最小数值)。
53    /// Stage 2 的豁免保底已经把 KeyTool ToolGroup 单独保留,Stage 3 的比例
54    /// 配额只对未保留 unit 生效,故实际参与 Stage 3 筛选的 ToolGroup 通常是 RegularTool。
55    fn priority(&self) -> u8 {
56        match self {
57            MessageUnit::System { .. } => ContextTier::System.priority(),
58            MessageUnit::User { .. } => ContextTier::User.priority(),
59            MessageUnit::AssistantText { .. } => ContextTier::Assistant.priority(),
60            MessageUnit::ToolGroup { .. } => ContextTier::RegularTool.priority(),
61        }
62    }
63
64    /// 该单元包含的消息条数
65    fn msg_count(&self) -> usize {
66        match self {
67            MessageUnit::System { .. }
68            | MessageUnit::User { .. }
69            | MessageUnit::AssistantText { .. } => 1,
70            MessageUnit::ToolGroup {
71                tool_result_indices,
72                ..
73            } => 1 + tool_result_indices.len(),
74        }
75    }
76
77    /// 该单元中第一条消息的索引(用于时间排序)
78    fn first_idx(&self) -> usize {
79        match self {
80            MessageUnit::System { message_index }
81            | MessageUnit::User { message_index }
82            | MessageUnit::AssistantText { message_index } => *message_index,
83            MessageUnit::ToolGroup {
84                assistant_message_index,
85                ..
86            } => *assistant_message_index,
87        }
88    }
89
90    /// 估算该单元的 token 数(用 chars 计数 + /3,兼顾中文场景)
91    fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
92        let total_chars: usize = match self {
93            MessageUnit::System { message_index }
94            | MessageUnit::User { message_index }
95            | MessageUnit::AssistantText { message_index } => {
96                messages[*message_index].content.chars().count()
97            }
98            MessageUnit::ToolGroup {
99                assistant_message_index,
100                tool_result_indices,
101            } => {
102                let mut chars = messages[*assistant_message_index].content.chars().count();
103                for &result_index in tool_result_indices {
104                    chars += messages[result_index].content.chars().count();
105                }
106                if let Some(ref tcs) = messages[*assistant_message_index].tool_calls {
107                    for tc in tcs {
108                        chars += tc.name.chars().count() + tc.arguments.chars().count();
109                    }
110                }
111                chars
112            }
113        };
114        total_chars / SIMPLE_CHARS_PER_TOKEN
115    }
116
117    /// ToolGroup 是否包含豁免工具(任一 tool_call 命中豁免清单即返回 true)
118    fn has_exempt_tool(&self, messages: &[ChatMessage], exempt_tools: &[String]) -> bool {
119        match self {
120            MessageUnit::ToolGroup {
121                assistant_message_index,
122                ..
123            } => messages[*assistant_message_index]
124                .tool_calls
125                .as_ref()
126                .map(|tcs| tcs.iter().any(|tc| is_exempt_tool(&tc.name, exempt_tools)))
127                .unwrap_or(false),
128            _ => false,
129        }
130    }
131}
132
133// ========== 解析 ==========
134
135/// 将消息序列解析为 MessageUnit 列表
136fn parse_message_units(messages: &[ChatMessage]) -> Vec<MessageUnit> {
137    let mut units = Vec::with_capacity(messages.len());
138    let mut i = 0;
139
140    while i < messages.len() {
141        let msg = &messages[i];
142
143        if msg.role == MessageRole::System {
144            units.push(MessageUnit::System { message_index: i });
145            i += 1;
146        } else if msg.role == MessageRole::User {
147            units.push(MessageUnit::User { message_index: i });
148            i += 1;
149        } else if msg.role == MessageRole::Assistant {
150            if msg.tool_calls.is_some() {
151                // assistant + tool_calls → 收集后续 tool result
152                let assistant_message_index = i;
153                let mut tool_result_indices = Vec::new(); // 大小未知,无法预分配
154                i += 1;
155                while i < messages.len() && messages[i].role == MessageRole::Tool {
156                    tool_result_indices.push(i);
157                    i += 1;
158                }
159                units.push(MessageUnit::ToolGroup {
160                    assistant_message_index,
161                    tool_result_indices,
162                });
163            } else {
164                // 纯文字 assistant 消息
165                units.push(MessageUnit::AssistantText { message_index: i });
166                i += 1;
167            }
168        } else if msg.role == MessageRole::Tool {
169            // 孤立的 tool result(没有前面的 assistant+tool_calls)
170            // 作为 ToolGroup 处理(只有 result 没有 assistant)
171            let start = i;
172            let mut tool_result_indices = vec![i];
173            i += 1;
174            while i < messages.len() && messages[i].role == MessageRole::Tool {
175                tool_result_indices.push(i);
176                i += 1;
177            }
178            // 孤立 tool results 最低优先级,作为 ToolGroup 处理
179            units.push(MessageUnit::ToolGroup {
180                assistant_message_index: start, // 没有真正的 assistant,用第一个 tool result 的索引
181                tool_result_indices,
182            });
183        } else {
184            // 未知角色,作为单条处理
185            units.push(MessageUnit::System { message_index: i });
186            i += 1;
187        }
188    }
189
190    units
191}
192
193// ========== 优先级选择 ==========
194
/// Result of `select_units`.
struct SelectionResult {
    /// One flag per unit, in the same order as the `units` slice passed to
    /// `select_units`: `true` means the unit is retained.
    retained: Vec<bool>,
}
200
/// Read-only limits and configuration for `select_units`
/// (`units` and `messages` are passed to it separately).
struct SelectUnitsParams<'a> {
    /// Maximum total number of messages to retain
    /// (the caller in this file passes `usize::MAX` for "unlimited").
    max_history_messages: usize,
    /// Maximum estimated token count to retain
    /// (the caller in this file passes `usize::MAX` for "unlimited").
    max_context_tokens: usize,
    /// Base size of the Stage 1 recency floor; multiplied by `WINDOW_KEEP_RECENT_MULTIPLIER`.
    keep_recent: usize,
    /// Tool names whose ToolGroups are protected by the Stage 2 exemption floor.
    exempt_tools: &'a [String],
}
210
211/// 三阶段选择逻辑:
212/// Stage 1: 保留最近 N 个 unit(时间保底)
213/// Stage 2: 保留豁免 ToolGroup(技能/任务保底)
214/// Stage 3: 按比例配额分配剩余预算
215fn select_units(
216    units: &[MessageUnit],
217    messages: &[ChatMessage],
218    params: &SelectUnitsParams,
219) -> SelectionResult {
220    // 解构出局部变量,保持函数体不变
221    let max_history_messages = params.max_history_messages;
222    let max_context_tokens = params.max_context_tokens;
223    let keep_recent = params.keep_recent;
224    let exempt_tools = params.exempt_tools;
225    let mut retained_flags = vec![false; units.len()];
226    let mut used_message_count = 0usize;
227    let mut used_token_count = 0usize;
228
229    // 记账 + 预算检查的闭包式辅助(rust 中改为函数避免借用问题)
230    let try_retain_unit = |message_index: usize,
231                           retained: &mut [bool],
232                           used_message_count: &mut usize,
233                           used_token_count: &mut usize|
234     -> bool {
235        if retained[message_index] {
236            return false;
237        }
238        let unit = &units[message_index];
239        let unit_msg_count = unit.msg_count();
240        let unit_tokens = unit.estimate_tokens(messages);
241        if *used_message_count + unit_msg_count > max_history_messages
242            || *used_token_count + unit_tokens > max_context_tokens
243        {
244            return false;
245        }
246        retained[message_index] = true;
247        *used_message_count += unit_msg_count;
248        *used_token_count += unit_tokens;
249        true
250    };
251
252    // ── System:始终保留(不计配额)──
253    for (i, unit) in units.iter().enumerate() {
254        if matches!(unit, MessageUnit::System { .. }) {
255            // System 即使超预算也保留(通常极少极短)
256            retained_flags[i] = true;
257            used_message_count += unit.msg_count();
258            used_token_count += unit.estimate_tokens(messages);
259        }
260    }
261
262    // ── Stage 1: 时间保底 ── 最近 K 个非 System unit 无条件保留
263    let recent_units_to_keep = keep_recent.saturating_mul(WINDOW_KEEP_RECENT_MULTIPLIER);
264    let mut stage1_retained_count = 0usize;
265    for i in (0..units.len()).rev() {
266        if stage1_retained_count >= recent_units_to_keep {
267            break;
268        }
269        if matches!(units[i], MessageUnit::System { .. }) {
270            continue;
271        }
272        if try_retain_unit(
273            i,
274            &mut retained_flags,
275            &mut used_message_count,
276            &mut used_token_count,
277        ) {
278            stage1_retained_count += 1;
279        } else {
280            // 预算耗尽即停(最新的装不下,更老的也别试了)
281            break;
282        }
283    }
284
285    // ── Stage 2: 豁免保底 ── 含豁免工具的 ToolGroup 按时间倒序保留
286    for i in (0..units.len()).rev() {
287        if retained_flags[i] {
288            continue;
289        }
290        if units[i].has_exempt_tool(messages, exempt_tools) {
291            try_retain_unit(
292                i,
293                &mut retained_flags,
294                &mut used_message_count,
295                &mut used_token_count,
296            );
297        }
298    }
299
300    // ── Stage 3: 比例配额 ── 剩余预算按比例分给三个 tier,tier 内按时间倒序
301    let remaining_msgs = max_history_messages.saturating_sub(used_message_count);
302    let remaining_toks = max_context_tokens.saturating_sub(used_token_count);
303
304    // tier 数值与 ContextTier::priority() 对齐:
305    // User=1, Assistant=3, RegularTool=4(KeyTool=2 已在 Stage 2 豁免保底)
306    let quotas: [(u8, f32); 3] = [
307        (ContextTier::User.priority(), WINDOW_QUOTA_USER),
308        (ContextTier::Assistant.priority(), WINDOW_QUOTA_ASST_TEXT),
309        (ContextTier::RegularTool.priority(), WINDOW_QUOTA_TOOL_GROUP),
310    ];
311
312    for (tier_prio, ratio) in quotas {
313        // tier 子预算(向下取整;溢出阶段会吸收未用完部分)
314        let tier_message_budget = ((remaining_msgs as f32) * ratio) as usize;
315        let tier_token_budget = ((remaining_toks as f32) * ratio) as usize;
316        let tier_start_msg_count = used_message_count;
317        let tier_start_token_count = used_token_count;
318
319        // 该 tier 未保留的 unit,按时间倒序
320        let mut tier_candidates: Vec<usize> = (0..units.len())
321            .filter(|&i| !retained_flags[i] && units[i].priority() == tier_prio)
322            .collect();
323        tier_candidates.sort_by(|&a, &b| units[b].first_idx().cmp(&units[a].first_idx()));
324
325        for idx in tier_candidates {
326            let unit = &units[idx];
327            let unit_msg_count = unit.msg_count();
328            let unit_tokens = unit.estimate_tokens(messages);
329            // 子预算 + 全局预算双检查
330            if used_message_count - tier_start_msg_count + unit_msg_count > tier_message_budget {
331                continue;
332            }
333            if used_token_count - tier_start_token_count + unit_tokens > tier_token_budget {
334                continue;
335            }
336            try_retain_unit(
337                idx,
338                &mut retained_flags,
339                &mut used_message_count,
340                &mut used_token_count,
341            );
342        }
343    }
344
345    // ── Stage 4: 溢出 ── 剩余预算按时间倒序贪心填充未保留 unit(任意 tier)
346    for i in (0..units.len()).rev() {
347        try_retain_unit(
348            i,
349            &mut retained_flags,
350            &mut used_message_count,
351            &mut used_token_count,
352        );
353    }
354
355    // ── 兜底 ── 至少保留最新 User unit
356    let has_user_retained = units
357        .iter()
358        .enumerate()
359        .any(|(i, u)| matches!(u, MessageUnit::User { .. }) && retained_flags[i]);
360    if !has_user_retained
361        && let Some(last_user_idx) = (0..units.len())
362            .rev()
363            .find(|&i| matches!(units[i], MessageUnit::User { .. }))
364    {
365        retained_flags[last_user_idx] = true;
366    }
367
368    SelectionResult {
369        retained: retained_flags,
370    }
371}
372
373// ========== 占位符替换 ==========
374
375/// 提取 ToolGroup 的工具名称列表(用于占位符)
376fn tool_names_of(unit: &MessageUnit, messages: &[ChatMessage]) -> Vec<String> {
377    match unit {
378        MessageUnit::ToolGroup {
379            assistant_message_index,
380            ..
381        } => messages[*assistant_message_index]
382            .tool_calls
383            .as_ref()
384            .map(|tcs| tcs.iter().map(|tc| tc.name.clone()).collect())
385            .unwrap_or_default(),
386        _ => Vec::new(),
387    }
388}
389
390/// 合并一批被丢弃的 ToolGroup 名称为单条占位符 assistant 消息
391/// 与 micro_compact 的占位符风格对齐:`[Previous: used X, Y, Z]`
392fn merged_placeholder(names: &[String]) -> ChatMessage {
393    let content = if names.is_empty() {
394        "[Previous tool calls dropped]".to_string()
395    } else {
396        format!("[Previous: used {}]", names.join(", "))
397    };
398    ChatMessage::text(MessageRole::Assistant, content)
399}
400
401// ========== 公开接口 ==========
402
403/// 优先级消息窗口选择(三阶段 + 比例配额 + 溢出 + 占位符合并)
404///
405/// # 参数
406/// - `messages`: 原始消息列表
407/// - `max_history_messages`: 消息条数上限(0 = 不限制)
408/// - `max_context_tokens_k`: token 预算上限,单位 K(0 = 不限制,100 = 100K tokens)
409/// - `keep_recent`: 与 `CompactConfig.keep_recent` 对齐;最近 `keep_recent * WINDOW_KEEP_RECENT_MULTIPLIER`
410///   个非 System unit 在 Stage 1 无条件保留
411/// - `exempt_tools`: 来自 `CompactConfig.micro_compact_exempt_tools`;含豁免工具的 ToolGroup
412///   在 Stage 2 优先保留,保护 skill/task 等承载关键上下文的调用
413pub fn select_messages(
414    messages: &[ChatMessage],
415    max_history_messages: usize,
416    max_context_tokens_k: usize,
417    keep_recent: usize,
418    exempt_tools: &[String],
419) -> Vec<ChatMessage> {
420    let max_msgs = if max_history_messages == 0 {
421        usize::MAX
422    } else {
423        max_history_messages
424    };
425    let max_tokens = if max_context_tokens_k == 0 {
426        usize::MAX
427    } else {
428        max_context_tokens_k * TOKEN_K_MULTIPLIER
429    };
430
431    let total_tokens = estimate_tokens_simple(messages);
432    if messages.len() <= max_msgs && total_tokens <= max_tokens {
433        return messages.to_vec();
434    }
435
436    let units = parse_message_units(messages);
437    let selection = select_units(
438        &units,
439        messages,
440        &SelectUnitsParams {
441            max_history_messages: max_msgs,
442            max_context_tokens: max_tokens,
443            keep_recent,
444            exempt_tools,
445        },
446    );
447
448    // 按原始顺序重组消息;被丢弃的相邻 ToolGroup 合并为单个占位符
449    let mut result = Vec::with_capacity(messages.len());
450    let mut pending_dropped_names: Vec<String> = Vec::new(); // 大小未知
451
452    let flush_pending = |pending: &mut Vec<String>, out: &mut Vec<ChatMessage>| {
453        if !pending.is_empty() {
454            out.push(merged_placeholder(pending));
455            pending.clear();
456        }
457    };
458
459    for (i, unit) in units.iter().enumerate() {
460        if selection.retained[i] {
461            flush_pending(&mut pending_dropped_names, &mut result);
462            match unit {
463                MessageUnit::System { message_index }
464                | MessageUnit::User { message_index }
465                | MessageUnit::AssistantText { message_index } => {
466                    result.push(messages[*message_index].clone());
467                }
468                MessageUnit::ToolGroup {
469                    assistant_message_index,
470                    tool_result_indices,
471                } => {
472                    result.push(messages[*assistant_message_index].clone());
473                    for &result_index in tool_result_indices {
474                        result.push(messages[result_index].clone());
475                    }
476                }
477            }
478        } else if matches!(unit, MessageUnit::ToolGroup { .. }) {
479            // 累积相邻被丢弃的 ToolGroup,后续一次性输出合并占位符
480            pending_dropped_names.extend(tool_names_of(unit, messages));
481        }
482        // User / AssistantText 丢弃时直接跳过(兜底保证最新 User 一定保留)
483    }
484    flush_pending(&mut pending_dropped_names, &mut result);
485
486    let dropped_count = selection.retained.iter().filter(|&&r| !r).count();
487    if dropped_count > 0 {
488        write_info_log(
489            "window_select",
490            &format!(
491                "三阶段窗口选择: 保留 {}/{} 单元, 丢弃 {} (tokens: {}→{}, keep_recent={})",
492                units.len() - dropped_count,
493                units.len(),
494                dropped_count,
495                total_tokens,
496                estimate_tokens_simple(&result),
497                keep_recent,
498            ),
499        );
500    }
501
502    result
503}
504
505/// 简易 token 估算(用于整体判断;与 MessageUnit::estimate_tokens 保持相同系数)
506fn estimate_tokens_simple(messages: &[ChatMessage]) -> usize {
507    let total_chars: usize = messages
508        .iter()
509        .map(|m| {
510            let mut chars = m.content.chars().count();
511            if let Some(ref tcs) = m.tool_calls {
512                for tc in tcs {
513                    chars += tc.name.chars().count() + tc.arguments.chars().count();
514                }
515            }
516            chars
517        })
518        .sum();
519    total_chars / 3
520}
521
522#[cfg(test)]
523mod tests;