// aster/providers/codex_stateful.rs

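//! Stateful Codex provider implementation.
//!
//! This module uses the Codex app-server protocol to manage stateful sessions,
//! enabling multi-turn conversations whose context stays coherent across turns.
//!
//! Unlike the existing `codex.rs` (exec mode), this implementation:
//! - maintains a long-lived app-server process
//! - keeps session state through the thread/turn mechanism
//! - supports session resumption (`thread/resume`)
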
11use anyhow::Result;
12use async_trait::async_trait;
13use once_cell::sync::Lazy;
14use serde_json::json;
15use std::ffi::OsString;
16use std::path::{Path, PathBuf};
17use std::sync::Mutex;
18
19use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
20use super::codex::{CODEX_DEFAULT_MODEL, CODEX_DOC_URL, CODEX_KNOWN_MODELS};
21use super::codex_app_server::{AppServerEvent, CodexSessionManager};
22use super::errors::ProviderError;
23use super::utils::RequestLog;
24use crate::config::base::{CodexCommand, CodexReasoningEffort, CodexUseAppServer};
25use crate::config::search_path::SearchPaths;
26use crate::config::Config;
27use crate::conversation::message::{Message, MessageContent};
28use crate::model::ModelConfig;
29use rmcp::model::Role;
30use rmcp::model::Tool;
31
32/// 全局会话管理器
33static SESSION_MANAGER: Lazy<Mutex<Option<CodexSessionManager>>> = Lazy::new(|| Mutex::new(None));
34
35/// 获取或初始化会话管理器
36fn get_session_manager(command: &Path) -> Result<(), ProviderError> {
37    let mut manager = SESSION_MANAGER
38        .lock()
39        .map_err(|e| ProviderError::RequestFailed(format!("获取会话管理器锁失败: {}", e)))?;
40
41    if manager.is_none() {
42        *manager = Some(CodexSessionManager::new(command.to_path_buf()));
43    }
44
45    Ok(())
46}
47
48/// Codex 有状态 Provider
49#[derive(Debug)]
50pub struct CodexStatefulProvider {
51    command: PathBuf,
52    model: ModelConfig,
53    name: String,
54    reasoning_effort: String,
55}
56
57impl CodexStatefulProvider {
58    /// 从环境创建 Provider
59    pub async fn from_env(model: ModelConfig) -> Result<Self> {
60        let config = Config::global();
61        let command: OsString = config.get_codex_command().unwrap_or_default().into();
62        let resolved_command = SearchPaths::builder().with_npm().resolve(command)?;
63
64        let reasoning_effort = config
65            .get_codex_reasoning_effort()
66            .map(|r| r.to_string())
67            .unwrap_or_else(|_| "high".to_string());
68
69        Ok(Self {
70            command: resolved_command,
71            model,
72            name: "codex-stateful".to_string(),
73            reasoning_effort,
74        })
75    }
76
77    /// 检查是否应该使用 app-server 模式
78    pub fn should_use_app_server() -> bool {
79        let config = Config::global();
80        config
81            .get_codex_use_app_server()
82            .map(|s| s.to_lowercase() == "true")
83            .unwrap_or(true)
84    }
85
86    /// 将消息转换为用户输入文本
87    fn messages_to_input(&self, system: &str, messages: &[Message]) -> String {
88        let mut input = String::new();
89
90        // 添加系统提示(如果有)
91        if !system.is_empty() {
92            input.push_str("[System Instructions]\n");
93            input.push_str(system);
94            input.push_str("\n\n");
95        }
96
97        // 只取最后一条用户消息作为当前输入
98        // 历史消息由 app-server 的 thread 机制维护
99        if let Some(last_user_msg) = messages.iter().rev().find(|m| m.role == Role::User) {
100            for content in &last_user_msg.content {
101                if let MessageContent::Text(text_content) = content {
102                    input.push_str(&text_content.text);
103                }
104            }
105        }
106
107        input
108    }
109
110    /// 生成会话 ID(基于消息内容的哈希)
111    fn generate_conversation_id(&self, messages: &[Message]) -> String {
112        use std::collections::hash_map::DefaultHasher;
113        use std::hash::{Hash, Hasher};
114
115        let mut hasher = DefaultHasher::new();
116
117        // 使用第一条用户消息作为会话标识
118        if let Some(first_user_msg) = messages.iter().find(|m| m.role == Role::User) {
119            for content in &first_user_msg.content {
120                if let MessageContent::Text(text_content) = content {
121                    text_content.text.hash(&mut hasher);
122                    break;
123                }
124            }
125        }
126
127        format!("conv_{:x}", hasher.finish())
128    }
129
130    /// 使用 app-server 执行请求
131    fn execute_with_app_server(
132        &self,
133        system: &str,
134        messages: &[Message],
135    ) -> Result<(String, Usage), ProviderError> {
136        // 初始化会话管理器
137        get_session_manager(&self.command)?;
138
139        let conversation_id = self.generate_conversation_id(messages);
140        let input = self.messages_to_input(system, messages);
141
142        // 获取当前工作目录
143        let cwd = std::env::current_dir()
144            .ok()
145            .map(|p| p.to_string_lossy().to_string());
146
147        // 获取或创建连接
148        {
149            let manager = SESSION_MANAGER.lock().map_err(|e| {
150                ProviderError::RequestFailed(format!("获取会话管理器锁失败: {}", e))
151            })?;
152
153            if let Some(mgr) = manager.as_ref() {
154                mgr.get_or_create_connection(
155                    &conversation_id,
156                    cwd.as_deref(),
157                    Some(&self.model.model_name),
158                )?;
159            }
160        }
161
162        // 发送消息
163        let (response_text, events) = {
164            let manager = SESSION_MANAGER.lock().map_err(|e| {
165                ProviderError::RequestFailed(format!("获取会话管理器锁失败: {}", e))
166            })?;
167
168            if let Some(mgr) = manager.as_ref() {
169                mgr.send_message(
170                    &conversation_id,
171                    &input,
172                    Some(&self.model.model_name),
173                    Some(&self.reasoning_effort),
174                )?
175            } else {
176                return Err(ProviderError::RequestFailed(
177                    "会话管理器未初始化".to_string(),
178                ));
179            }
180        };
181
182        // 从事件中提取 usage 信息
183        let usage = self.extract_usage_from_events(&events);
184
185        if std::env::var("ASTER_CODEX_DEBUG").is_ok() {
186            println!("=== CODEX STATEFUL DEBUG ===");
187            println!("Conversation ID: {}", conversation_id);
188            println!("Input: {}", input);
189            println!("Response: {}", response_text);
190            println!("Events count: {}", events.len());
191            println!("============================");
192        }
193
194        Ok((response_text, usage))
195    }
196
197    /// 从事件中提取 usage 信息
198    fn extract_usage_from_events(&self, _events: &[AppServerEvent]) -> Usage {
199        // TODO: 从 turn/completed 事件中提取 token 使用量
200        // 目前 app-server 协议的 usage 信息可能在 turn/completed 的 params 中
201        Usage::default()
202    }
203
204    /// 生成简单的会话描述
205    fn generate_simple_session_description(
206        &self,
207        messages: &[Message],
208    ) -> Result<(Message, ProviderUsage), ProviderError> {
209        let description = messages
210            .iter()
211            .find(|m| m.role == Role::User)
212            .and_then(|m| {
213                m.content.iter().find_map(|c| match c {
214                    MessageContent::Text(text_content) => Some(&text_content.text),
215                    _ => None,
216                })
217            })
218            .map(|text| {
219                text.split_whitespace()
220                    .take(4)
221                    .collect::<Vec<_>>()
222                    .join(" ")
223            })
224            .unwrap_or_else(|| "Simple task".to_string());
225
226        let message = Message::new(
227            Role::Assistant,
228            chrono::Utc::now().timestamp(),
229            vec![MessageContent::text(description)],
230        );
231
232        Ok((
233            message,
234            ProviderUsage::new(self.model.model_name.clone(), Usage::default()),
235        ))
236    }
237}
238
239#[async_trait]
240impl Provider for CodexStatefulProvider {
241    fn metadata() -> ProviderMetadata {
242        ProviderMetadata::new(
243            "codex-stateful",
244            "OpenAI Codex CLI (Stateful)",
245            "使用 app-server 协议的有状态 Codex Provider,支持会话持久化和上下文连贯。",
246            CODEX_DEFAULT_MODEL,
247            CODEX_KNOWN_MODELS.to_vec(),
248            CODEX_DOC_URL,
249            vec![
250                ConfigKey::from_value_type::<CodexCommand>(true, false),
251                ConfigKey::from_value_type::<CodexReasoningEffort>(false, false),
252                ConfigKey::from_value_type::<CodexUseAppServer>(false, false),
253            ],
254        )
255    }
256
257    fn get_name(&self) -> &str {
258        &self.name
259    }
260
261    fn get_model_config(&self) -> ModelConfig {
262        self.model.clone()
263    }
264
265    #[tracing::instrument(
266        skip(self, model_config, system, messages, _tools),
267        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
268    )]
269    async fn complete_with_model(
270        &self,
271        model_config: &ModelConfig,
272        system: &str,
273        messages: &[Message],
274        _tools: &[Tool],
275    ) -> Result<(Message, ProviderUsage), ProviderError> {
276        // 会话描述请求使用简单方式
277        if system.contains("four words or less") || system.contains("4 words or less") {
278            return self.generate_simple_session_description(messages);
279        }
280
281        // 使用 app-server 执行
282        let (response_text, usage) = self.execute_with_app_server(system, messages)?;
283
284        if response_text.is_empty() {
285            return Err(ProviderError::RequestFailed(
286                "Codex app-server 返回空响应".to_string(),
287            ));
288        }
289
290        let message = Message::new(
291            Role::Assistant,
292            chrono::Utc::now().timestamp(),
293            vec![MessageContent::text(response_text)],
294        );
295
296        // 记录请求日志
297        let payload = json!({
298            "command": self.command,
299            "model": model_config.model_name,
300            "reasoning_effort": self.reasoning_effort,
301            "mode": "app-server",
302            "messages_count": messages.len()
303        });
304
305        let mut log = RequestLog::start(model_config, &payload)
306            .map_err(|e| ProviderError::RequestFailed(format!("记录请求日志失败: {}", e)))?;
307
308        let response = json!({
309            "usage": usage
310        });
311
312        log.write(&response, Some(&usage))
313            .map_err(|e| ProviderError::RequestFailed(format!("写入请求日志失败: {}", e)))?;
314
315        Ok((
316            message,
317            ProviderUsage::new(model_config.model_name.clone(), usage),
318        ))
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// The provider must advertise its own name and at least one known model.
    #[test]
    fn test_metadata() {
        let meta = CodexStatefulProvider::metadata();
        assert_eq!(meta.name, "codex-stateful");
        assert!(!meta.known_models.is_empty());
    }

    /// Smoke test only: the result depends on the global configuration and environment,
    /// so the returned flag is deliberately not asserted.
    #[test]
    fn test_should_use_app_server_default() {
        let _ = CodexStatefulProvider::should_use_app_server();
    }
}
340}