Skip to main content

atomcode_core/ctx/
ollama.rs

1//! [`OllamaCtx`] — 为本地小窗口 Ollama 模型优化的上下文策略。
2//!
3//! ## 与 [`DefaultCtx`] 的三点差异
4//!
5//! 1. **更早触发压缩**:总 tokens 超过窗口 35% 就压,而非默认的 50%。
6//!    8K 窗口下 ~2800 tokens 即启动压缩,给后续 turn 留呼吸空间。
7//! 2. **工具输出更紧**:单条 tool_result 上限 = `ctx/8` clamp `[2K, 6K]`
8//!    字节,显著低于 Default 的 `[8K, 32K]`。本地模型 8K 窗口下一条
9//!    bash 输出占一半预算是主要失败模式。
10//! 3. **窗口默认值降低**:若 `provider.context_window` 未设,fallback
11//!    到 8000(Default 是 128000)。匹配 Ollama CLI 的 `num_ctx` 常见值。
12//!
13//! ## 不做的事(明确范围)
14//!
15//! - **不砍 system prompt**:tool schema 作为独立参数传给 LLM API,
16//!   不在 `system_prompt: &str` 里。真要简化 system prompt 需要在
17//!   [`crate::agent::prompt`] 层面做。
18//! - **不改工具集筛选**:哪些工具暴露给模型是 [`crate::tool::ToolRegistry`]
19//!   的职责,与 ctx 无关。
20//! - **不重写 render/microcompact/replace_stale_reads**:`build_messages`
21//!   直接透传给 [`crate::ctx::render::build_messages`] —— 与默认行为同
22//!   pipeline,只是 ctx_window 更小、配合更紧的 tool-output 截断。
23//!   想要 render pipeline 级别的定制,完全重写自己的 `build_messages`
24//!   即可,不必受这里影响。
25//!
26//! 需要以上行为时,在上层扩展相应模块,不在 ctx 里做。
27
28use super::CtxBuilder;
29use crate::config::provider::ProviderConfig;
30use crate::conversation::message::Message;
31use crate::conversation::{ContextStats, Conversation};
32use crate::tool::ToolResult;
33
34/// 本地 Ollama 模型的上下文策略。
35#[derive(Debug, Clone)]
36pub struct OllamaCtx {
37    /// Token budget, 至少 4K(再低就没意义了)
38    ctx_window: usize,
39
40    /// Lowercased model id。用于 [`crate::ctx::render::apply_model_directives`]
41    /// 判断是否追加 CJK 语言锁 / MiniMax thinking 纪律。本地 Ollama 也
42    /// 常跑 qwen / deepseek / minimax 蒸馏版,同一套规则适用。
43    model_id: String,
44}
45
46impl OllamaCtx {
47    pub fn new(provider: &ProviderConfig) -> Self {
48        // Ollama 的默认 ctx 是 8000(见 default_context_window_for).
49        // 再给一个硬下限防 0 / 配置漂移。
50        Self {
51            ctx_window: provider.context_window.max(4000),
52            model_id: provider.model.to_lowercase(),
53        }
54    }
55
56    /// 单条 tool_result 硬字符上限: ctx/8 clamp [2K, 6K].
57    /// 对比 Default 的 ctx/8 clamp [8K, 32K] 显著更紧。
58    fn tool_output_cap(&self) -> usize {
59        (self.ctx_window / 8).min(6_000).max(2_000)
60    }
61}
62
63impl CtxBuilder for OllamaCtx {
64    fn build_messages(
65        &self,
66        conv: &Conversation,
67        system_prompt: &str,
68        turn_reminder: &str,
69    ) -> (Vec<Message>, ContextStats) {
70        // 渲染透传给默认 render 管道,仅把 ctx_window 传下去决定
71        // token 预算; cold zone / microcompact / hard-cut / turn_reminder
72        // 注入的具体策略由 ctx::render 统一执行。
73        // model_id 依赖的指令(CJK 语言锁 / MiniMax thinking 纪律)
74        // 在渲染管道前贴到 system prompt 上,与 DefaultCtx 一致。
75        let sys = crate::ctx::render::apply_model_directives(system_prompt, &self.model_id);
76        crate::ctx::render::build_messages(conv, &sys, self.ctx_window, turn_reminder)
77    }
78
79    /// 复用 ctx::render::needs_compression — 它的绝对 headroom 公式
80    /// `ctx_window - min(13K, ctx_window/4)` 在小窗口下天然偏紧
81    /// (8K Ollama → 6K threshold = 75% 触发, 比之前的 35% 晚但更接近"撑爆前一刻"的真实 headroom)。
82    /// 之前的 35% hardcoded 阈值是为 4-8K Ollama 量身的早触发, 但在
83    /// 16K-32K Ollama 上反而过早。新公式自适应窗口大小, 不再需要单独的 Ollama tier。
84    fn needs_compression(&self, conv: &Conversation, system_tokens: usize) -> bool {
85        crate::ctx::render::needs_compression(conv, system_tokens, self.ctx_window)
86    }
87
88    fn compression_plan(&self, conv: &Conversation) -> Option<(String, usize)> {
89        // 决策用的是 self.needs_compression(35% 早触发),
90        // plan 内容生成沿用 ctx::render 的 one-line-per-round 机械摘要。
91        let (content, n) = crate::ctx::render::build_compression_content(conv);
92        if content.is_empty() || n == 0 {
93            None
94        } else {
95            Some((content, n))
96        }
97    }
98
99    fn truncate_tool_output(&self, result: &mut ToolResult, tool_name: &str) {
100        // 先走共享的 per-tool 截断(bash 保错误行、read_file 出 skeleton、
101        // web_fetch head+tail 等)。传入 self.ctx_window 让内部公式知道
102        // 窗口小。
103        crate::ctx::truncate::truncate_output(result, tool_name, self.ctx_window);
104
105        // 再套 Ollama tier 的硬上限,belt-and-suspenders。
106        let cap = self.tool_output_cap();
107        if result.output.len() > cap {
108            // UTF-8 安全截断:cap 可能落在 multi-byte char 中间,
109            // 走到前一个 char boundary 再切。
110            let mut boundary = cap;
111            while boundary > 0 && !result.output.is_char_boundary(boundary) {
112                boundary -= 1;
113            }
114            result.output.truncate(boundary);
115            result
116                .output
117                .push_str("\n[... truncated by OllamaCtx (small window) ...]");
118        }
119    }
120
121    fn ctx_window(&self) -> usize {
122        self.ctx_window
123    }
124
125    fn name(&self) -> &'static str {
126        "ollama"
127    }
128}
129
130#[cfg(test)]
131mod tests {
132    use super::*;
133    use crate::conversation::Conversation;
134    use crate::tool::ToolResult;
135
136    fn ollama_provider(ctx: usize) -> ProviderConfig {
137        ProviderConfig {
138            provider_type: "ollama".into(),
139            api_key: None,
140            model: "llama3-8b".into(),
141            base_url: Some("http://localhost:11434".into()),
142            system_prompt: None,
143            user_agent: None,
144            context_window: ctx,
145            max_tokens: None,
146            thinking_type: None,
147            thinking_keep: None,
148            reasoning_history: None,
149            thinking_enabled: None,
150            thinking_budget: None,
151            skip_tls_verify: false,
152            ephemeral: false,
153
154}
155    }
156
157    #[test]
158    fn name_is_ollama() {
159        let o = OllamaCtx::new(&ollama_provider(8_000));
160        assert_eq!(o.name(), "ollama");
161    }
162
163    #[test]
164    fn ctx_window_clamped_to_4k_minimum() {
165        // 防御:context_window = 0 或缺失时 fallback 到 4000
166        let o = OllamaCtx::new(&ollama_provider(0));
167        assert_eq!(o.ctx_window, 4_000);
168
169        let o = OllamaCtx::new(&ollama_provider(2_000));
170        assert_eq!(o.ctx_window, 4_000);
171
172        // 正常值不变
173        let o = OllamaCtx::new(&ollama_provider(8_000));
174        assert_eq!(o.ctx_window, 8_000);
175
176        let o = OllamaCtx::new(&ollama_provider(32_000));
177        assert_eq!(o.ctx_window, 32_000);
178    }
179
180    #[test]
181    fn tool_output_cap_follows_spec() {
182        // ctx=8K → 8000/8=1000, 被 max(2000) 抬到 2000
183        assert_eq!(
184            OllamaCtx::new(&ollama_provider(8_000)).tool_output_cap(),
185            2_000
186        );
187        // ctx=16K → 16000/8=2000, 正好等于下限
188        assert_eq!(
189            OllamaCtx::new(&ollama_provider(16_000)).tool_output_cap(),
190            2_000
191        );
192        // ctx=32K → 32000/8=4000, 在 [2K, 6K] 内
193        assert_eq!(
194            OllamaCtx::new(&ollama_provider(32_000)).tool_output_cap(),
195            4_000
196        );
197        // ctx=64K → 64000/8=8000, 被 min(6000) 压到 6000
198        assert_eq!(
199            OllamaCtx::new(&ollama_provider(64_000)).tool_output_cap(),
200            6_000
201        );
202    }
203
204    #[test]
205    fn truncate_result_enforces_small_cap() {
206        let o = OllamaCtx::new(&ollama_provider(8_000));
207        let mut r = ToolResult {
208            call_id: "t1".into(),
209            output: "x".repeat(50_000),
210            success: true,
211        };
212        o.truncate_tool_output(&mut r, "bash");
213        // tool_output_cap() = 2000, 加上后缀消息大约 +50 字节
214        assert!(
215            r.output.len() <= 2_200,
216            "OllamaCtx truncate 后输出 {} 字节超过 cap 2200",
217            r.output.len()
218        );
219    }
220
221    #[test]
222    fn truncate_result_utf8_safe_on_cjk_boundary() {
223        // 回归:3 字节 CJK 字符重复,裁切点可能落在 char 中间,
224        // String::truncate 本身会 panic。is_char_boundary 循环修正。
225        let o = OllamaCtx::new(&ollama_provider(8_000));
226        let mut r = ToolResult {
227            call_id: "t1".into(),
228            output: "中".repeat(5_000), // 15000 字节,远超 2K cap
229            success: true,
230        };
231        o.truncate_tool_output(&mut r, "bash");
232        // 不 panic + 输出仍是合法 UTF-8
233        assert!(std::str::from_utf8(r.output.as_bytes()).is_ok());
234        assert!(r.output.len() <= 2_200);
235    }
236
237    #[test]
238    fn needs_compression_triggers_earlier_than_default() {
239        let o = OllamaCtx::new(&ollama_provider(8_000));
240        // 空对话不触发
241        let empty = Conversation::new();
242        assert!(!o.needs_compression(&empty, 100));
243
244        // 构造超过 35% 阈值(= 2800 tokens)的对话,模型数也够(>= 12)
245        let mut conv = Conversation::new();
246        for i in 0..8 {
247            conv.add_user_message(&format!("user turn {} with moderate content", i));
248            conv.add_assistant_tool_calls(
249                Some(&format!("some assistant reasoning for turn {}", i)),
250                vec![],
251                None,
252            );
253        }
254        // 16 条消息,每条 ~10-15 tokens → 总 ~200 tokens,低于 35%,不压
255        assert!(!o.needs_compression(&conv, 50));
256
257        // 再填大量长消息让总 tokens 超过 2800
258        for _ in 0..20 {
259            conv.add_user_message(&"lorem ipsum ".repeat(50).repeat(2)); // 每条 ~250 tokens
260            conv.add_assistant_tool_calls(Some(&"dolor sit amet ".repeat(50)), vec![], None);
261        }
262        // 此时总 tokens 远超 2800
263        assert!(
264            o.needs_compression(&conv, 50),
265            "大对话下 OllamaCtx 应触发压缩(35% threshold)"
266        );
267    }
268
269    #[test]
270    fn compression_plan_none_below_threshold() {
271        let o = OllamaCtx::new(&ollama_provider(8_000));
272        let conv = Conversation::new();
273        assert!(o.compression_plan(&conv).is_none());
274    }
275
276    #[test]
277    fn build_messages_returns_nonempty_for_simple_conv() {
278        let o = OllamaCtx::new(&ollama_provider(8_000));
279        let mut conv = Conversation::new();
280        conv.add_user_message("hello");
281        let (msgs, stats) = o.build_messages(&conv, "SYS", "");
282        assert!(!msgs.is_empty());
283        assert!(stats.sent_tokens <= 8_000);
284    }
285}