Skip to main content

otherone_context/
lib.rs

1// 作用:otherone-context 模块 — 上下文管理(会话加载 + 压缩触发)
2// 关联:被 otherone-agent 的 invoke_agent 调用
3// 预期结果:提供 combine_context 核心方法
4
5pub mod compact;
6pub mod error;
7pub mod types;
8
9use tracing::warn;
10
11use crate::compact::check_threshold::check_threshold;
12use crate::compact::estimate_tokens::estimate_tokens;
13use crate::compact::compact_messages;
14use crate::error::ContextError;
15use crate::types::{CombineContextOptions, ContextLoadType};
16use otherone_ai::types::{Message, MessageContent};
17use otherone_storage::types::SessionData;
18
19/// 提取 MessageContent 中的文本
20fn message_content_text(content: &MessageContent) -> &str {
21    match content {
22        MessageContent::Text(t) => t.as_str(),
23        _ => "[非文本内容]",
24    }
25}
26
27/// 组合 context 配置
28/// 作用:根据 session_id 和 load_type 加载历史消息,检查 token 阈值并触发压缩
29/// 关联:被 agent loop 调用
30/// 预期结果:返回 messages 数组(包含 system prompt),供 AI 调用使用
31pub async fn combine_context(
32    options: &CombineContextOptions,
33) -> Result<Vec<Message>, ContextError> {
34    // 参数有效性检查
35    if options.session_id.is_empty() {
36        return Err(ContextError::ConfigError(
37            "session_id is required".to_string(),
38        ));
39    }
40    if options.context_window == 0 {
41        return Err(ContextError::ConfigError(
42            "context_window is required".to_string(),
43        ));
44    }
45
46    // 根据 load_type 加载数据
47    let session_data = match options.load_type {
48        ContextLoadType::Database => {
49            let db_config = options
50                .database_config
51                .as_ref()
52                .ok_or_else(|| {
53                    ContextError::ConfigError(
54                        "database_config is required when load_type is database".to_string(),
55                    )
56                })?;
57            otherone_storage::database::reader::read_session_data_from_database(
58                &options.session_id,
59                db_config,
60            )
61            .await
62            .map_err(|e| ContextError::StorageError(e.to_string()))?
63        }
64        ContextLoadType::LocalFile => {
65            otherone_storage::localfile::reader::read_session_data(&options.session_id)
66                .map_err(|e| ContextError::StorageError(e.to_string()))?
67        }
68    };
69
70    // 处理压缩记录,找到最新的压缩内容和起始 entry_id
71    let (compacted_summary, start_entry_id) =
72        process_compacted_entries(&session_data);
73
74    // 根据 start_entry_id 过滤 entries
75    let filtered_entries = filter_entries_from_start(&session_data.entries, &start_entry_id);
76
77    // 更新 session_data 的 entries 为过滤后的结果
78    let mut filtered_session = session_data;
79    filtered_session.entries = filtered_entries;
80
81    // 根据 provider 类型转换数据格式
82    let messages =
83        transform_to_messages(&filtered_session, &options.provider, compacted_summary.as_deref());
84
85    // 计算 token 用量并检查是否需要压缩
86    let mut context_tokens: u32 = 0;
87
88    // 获取最后一条 assistant 消息的 token_consumption
89    let assistant_entries: Vec<&otherone_storage::types::Entry> = filtered_session
90        .entries
91        .iter()
92        .rev()
93        .filter(|e| e.role == "assistant")
94        .collect();
95
96    // 最多检查 3 条 assistant 消息
97    let max_attempts = std::cmp::min(3, assistant_entries.len());
98    let mut found_token_index: Option<usize> = None;
99
100    for i in 0..max_attempts {
101        let entry = assistant_entries[i];
102        if let Some(tc) = entry.token_consumption {
103            context_tokens = tc;
104            if let Some(pos) = filtered_session.entries.iter().position(|e| e.entry_id == entry.entry_id) {
105                found_token_index = Some(pos);
106            }
107            break;
108        }
109    }
110
111    // 如果没有找到 token_consumption,或者找到后还有新消息,需要估算
112    if context_tokens == 0 || found_token_index.map_or(true, |idx| idx < filtered_session.entries.len() - 1) {
113        let estimated = estimate_tokens(&messages);
114
115        if context_tokens == 0 {
116            let mut extra_tokens: u32 = 0;
117            if let Some(ref sp) = options.system_prompt {
118                let sys_msg = Message {
119                    role: "system".to_string(),
120                    content: MessageContent::Text(sp.clone()),
121                    name: None,
122                    tool_calls: None,
123                    tool_call_id: None,
124                };
125                extra_tokens += estimate_tokens(&[sys_msg]);
126            }
127            if let Some(ref tools) = options.tools {
128                let tools_json = serde_json::to_string(tools).unwrap_or_default();
129                let tools_msg = Message {
130                    role: "system".to_string(),
131                    content: MessageContent::Text(tools_json),
132                    name: None,
133                    tool_calls: None,
134                    tool_call_id: None,
135                };
136                extra_tokens += estimate_tokens(&[tools_msg]);
137            }
138            context_tokens = estimated + extra_tokens;
139        } else {
140            context_tokens += estimated;
141        }
142    }
143
144    // 检查是否需要压缩
145    let should_compress = check_threshold(
146        context_tokens,
147        options.context_window,
148        options.threshold_percentage,
149    );
150
151    // 构建最终的 messages
152    let mut final_messages = messages;
153
154    if should_compress {
155        let has_compacted = final_messages.first().map_or(false, |m| {
156            m.role == "user"
157                && (message_content_text(&m.content).contains("[压缩")
158                    || message_content_text(&m.content).contains("[Compressed"))
159        });
160
161        warn!(
162            "Context threshold exceeded ({} tokens / {} window). Triggering compaction.",
163            context_tokens, options.context_window
164        );
165
166        // 实际执行压缩
167        let storage_type = match options.load_type {
168            ContextLoadType::Database => otherone_storage::types::StorageType::Database,
169            ContextLoadType::LocalFile => otherone_storage::types::StorageType::LocalFile,
170        };
171
172        match compact_messages(
173            &final_messages,
174            context_tokens,
175            options.context_window,
176            None,
177            options.ai.as_ref(),
178            has_compacted,
179            Some(&options.session_id),
180            Some(&storage_type),
181            options.database_config.as_ref(),
182            Some(&filtered_session.entries),
183        ).await {
184            Ok(compacted) => {
185                final_messages = compacted;
186            }
187            Err(e) => {
188                warn!("Compaction failed: {}. Continuing with uncompacted messages.", e);
189            }
190        }
191    }
192
193    // 在返回前添加 system prompt 到最前面
194    if let Some(ref system_prompt) = options.system_prompt {
195        final_messages.insert(0, Message {
196            role: "system".to_string(),
197            content: MessageContent::Text(system_prompt.clone()),
198            name: None,
199            tool_calls: None,
200            tool_call_id: None,
201        });
202    }
203
204    Ok(final_messages)
205}
206
207/// 处理压缩记录,找到最新的压缩内容和起始 entry_id
208fn process_compacted_entries(session_data: &SessionData) -> (Option<String>, Option<String>) {
209    if session_data.compacted_entries.is_empty() {
210        return (None, None);
211    }
212
213    let mut sorted: Vec<&otherone_storage::types::CompactedEntry> =
214        session_data.compacted_entries.iter().collect();
215    sorted.sort_by(|a, b| b.create_at.cmp(&a.create_at));
216
217    let latest = sorted[0];
218
219    if latest.summary.is_empty() {
220        return (None, None);
221    }
222
223    let final_entry_id = find_final_trigger_entry_id(
224        &latest.trigger_entry_id,
225        &session_data.compacted_entries,
226        &session_data.entries,
227    );
228
229    (Some(latest.summary.clone()), final_entry_id)
230}
231
232/// 递归查找最终的 trigger_entry_id
233fn find_final_trigger_entry_id(
234    trigger_id: &str,
235    compacted_entries: &[otherone_storage::types::CompactedEntry],
236    entries: &[otherone_storage::types::Entry],
237) -> Option<String> {
238    let max_depth = 10;
239    let mut current_id = trigger_id.to_string();
240    let mut visited_ids = std::collections::HashSet::new();
241
242    for _ in 0..max_depth {
243        if visited_ids.contains(&current_id) {
244            warn!("检测到循环引用: {}", current_id);
245            return None;
246        }
247        visited_ids.insert(current_id.clone());
248
249        if entries.iter().any(|e| e.entry_id == current_id) {
250            return Some(current_id);
251        }
252
253        if let Some(compacted) = compacted_entries
254            .iter()
255            .find(|c| c.entry_id == current_id)
256        {
257            current_id = compacted.trigger_entry_id.clone();
258        } else {
259            warn!("trigger_entry_id 不存在: {}", current_id);
260            return None;
261        }
262    }
263
264    warn!("超过最大递归深度,最后的 ID: {}", current_id);
265    None
266}
267
268/// 从 start_entry_id 开始过滤 entries
269fn filter_entries_from_start(
270    entries: &[otherone_storage::types::Entry],
271    start_entry_id: &Option<String>,
272) -> Vec<otherone_storage::types::Entry> {
273    let start_id = match start_entry_id {
274        Some(id) => id,
275        None => return entries.to_vec(),
276    };
277
278    if let Some(start_index) = entries.iter().position(|e| e.entry_id == *start_id) {
279        entries[start_index..].to_vec()
280    } else {
281        warn!("start_entry_id 不存在于 entries 中: {}", start_id);
282        entries.to_vec()
283    }
284}
285
286/// 将 session 数据转换为 AI 请求的 messages 格式
287fn transform_to_messages(
288    session_data: &SessionData,
289    provider: &otherone_ai::types::ProviderType,
290    compacted_summary: Option<&str>,
291) -> Vec<Message> {
292    match provider {
293        otherone_ai::types::ProviderType::OpenAI
294        | otherone_ai::types::ProviderType::OpenRouter
295        | otherone_ai::types::ProviderType::Local
296        | otherone_ai::types::ProviderType::Fetch => {
297            transform_to_openai_format(session_data, compacted_summary)
298        }
299        otherone_ai::types::ProviderType::Anthropic => {
300            transform_to_anthropic_format(session_data, compacted_summary)
301        }
302    }
303}
304
305/// 转换为 OpenAI 格式的 messages
306fn transform_to_openai_format(
307    session_data: &SessionData,
308    compacted_summary: Option<&str>,
309) -> Vec<Message> {
310    let mut messages: Vec<Message> = Vec::new();
311
312    if let Some(summary) = compacted_summary {
313        messages.push(Message {
314            role: "user".to_string(),
315            content: MessageContent::Text(summary.to_string()),
316            name: None,
317            tool_calls: None,
318            tool_call_id: None,
319        });
320    }
321
322    if session_data.entries.is_empty() {
323        return messages;
324    }
325
326    for entry in &session_data.entries {
327        let mut message = Message {
328            role: entry.role.clone(),
329            content: MessageContent::Text(entry.content.clone()),
330            name: None,
331            tool_calls: None,
332            tool_call_id: None,
333        };
334
335        if entry.role == "assistant" {
336            if let Some(ref tools) = entry.tools {
337                if let Some(tool_calls) = tools.get("tool_calls") {
338                    if let Ok(parsed_tool_calls) =
339                        serde_json::from_value(tool_calls.clone())
340                    {
341                        message.tool_calls = Some(parsed_tool_calls);
342                    }
343                }
344            }
345        }
346
347        if entry.role == "tool" {
348            if let Some(ref tools) = entry.tools {
349                message.tool_call_id = tools
350                    .get("tool_call_id")
351                    .and_then(|v| v.as_str())
352                    .map(|s| s.to_string());
353                message.name = tools
354                    .get("function_name")
355                    .and_then(|v| v.as_str())
356                    .map(|s| s.to_string());
357            }
358        }
359
360        messages.push(message);
361    }
362
363    messages
364}
365
366/// 转换为 Anthropic 格式的 messages
367fn transform_to_anthropic_format(
368    session_data: &SessionData,
369    compacted_summary: Option<&str>,
370) -> Vec<Message> {
371    transform_to_openai_format(session_data, compacted_summary)
372}