// aster/providers/codex_app_server.rs

//! Codex app-server 协议实现
//!
//! 该模块实现了与 Codex CLI 的 app-server 模式通信,
//! 支持会话持久化和上下文连贯。
//!
//! 协议基于 JSON-RPC 2.0 over stdio(每行一条 JSON 消息,省略 `"jsonrpc"` 字段),主要方法:
//! - initialize: 初始化连接
//! - thread/start: 创建新会话
//! - thread/resume: 恢复已有会话
//! - turn/start: 发送用户消息
//! - turn/interrupt: 中断当前回合

use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::path::PathBuf;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::thread;

use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use super::errors::ProviderError;

/// Process-wide source of JSON-RPC request ids; the first id handed out is 1.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);

/// Hands out the next JSON-RPC request id.
///
/// Ids are unique within the process and strictly increasing in the order the
/// calls are made.
fn next_request_id() -> u64 {
    AtomicU64::fetch_add(&REQUEST_ID, 1, Ordering::SeqCst)
}

31/// Thread 信息
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ThreadInfo {
34    pub id: String,
35    pub preview: Option<String>,
36    #[serde(rename = "modelProvider")]
37    pub model_provider: Option<String>,
38    #[serde(rename = "createdAt")]
39    pub created_at: Option<i64>,
40}
41
42/// Turn 信息
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TurnInfo {
45    pub id: String,
46    pub status: String,
47    pub items: Vec<TurnItem>,
48    pub error: Option<String>,
49}
50
/// An item inside a turn, discriminated by the JSON `type` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum TurnItem {
    /// Item with `"type": "agentMessage"`.
    #[serde(rename = "agentMessage")]
    AgentMessage {
        id: String,
        text: Option<String>,
        /// Defaults to `false` when the key is absent.
        #[serde(default)]
        complete: bool,
    },
    /// Item with `"type": "reasoning"`.
    #[serde(rename = "reasoning")]
    Reasoning {
        id: String,
        text: Option<String>,
        /// Defaults to `false` when the key is absent.
        #[serde(default)]
        complete: bool,
    },
    /// Item with `"type": "toolCall"`.
    #[serde(rename = "toolCall")]
    ToolCall {
        id: String,
        name: Option<String>,
        /// Defaults to `false` when the key is absent.
        #[serde(default)]
        complete: bool,
    },
    /// Catch-all for any other `type` value, so new item kinds do not break
    /// deserialization of the surrounding turn.
    #[serde(other)]
    Unknown,
}

80/// app-server 事件类型
81#[derive(Debug, Clone)]
82pub enum AppServerEvent {
83    /// 线程已启动
84    ThreadStarted(ThreadInfo),
85    /// Turn 已启动
86    TurnStarted(TurnInfo),
87    /// Item 开始
88    ItemStarted { item_id: String, item_type: String },
89    /// Agent 消息增量
90    AgentMessageDelta { item_id: String, text: String },
91    /// Reasoning 增量
92    ReasoningDelta { item_id: String, text: String },
93    /// Item 完成
94    ItemCompleted { item_id: String },
95    /// Turn 完成
96    TurnCompleted(TurnInfo),
97    /// 错误
98    Error(String),
99    /// 未知事件
100    Unknown(Value),
101}
102
/// A connection to a single `codex app-server` child process.
///
/// Messages are exchanged as one JSON document per line over the child's
/// stdin (outgoing) and stdout (incoming).
pub struct CodexAppServerConnection {
    /// The app-server child process.
    child: Child,
    /// Writer for the child's stdin (requests and notifications).
    stdin: ChildStdin,
    /// Buffered reader over the child's stdout (responses and notifications).
    stdout_reader: BufReader<ChildStdout>,
    /// Thread this connection is bound to; set by `thread_start` / `thread_resume`.
    current_thread_id: Option<String>,
    /// Pending responses keyed by request id.
    /// NOTE(review): nothing in this file inserts into or reads from this map —
    /// confirm whether it is still needed.
    pending_responses: HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, ProviderError>>>,
}

117impl CodexAppServerConnection {
118    /// 启动 app-server 进程
119    pub fn spawn(command: &PathBuf, cwd: Option<&str>) -> Result<Self, ProviderError> {
120        let mut cmd = Command::new(command);
121        cmd.arg("app-server")
122            .stdin(Stdio::piped())
123            .stdout(Stdio::piped())
124            .stderr(Stdio::piped());
125
126        if let Some(dir) = cwd {
127            cmd.current_dir(dir);
128        }
129
130        let mut child = cmd.spawn().map_err(|e| {
131            ProviderError::RequestFailed(format!(
132                "无法启动 Codex app-server: {}. 请确保已安装 Codex CLI (npm i -g @openai/codex)",
133                e
134            ))
135        })?;
136
137        let stdin = child
138            .stdin
139            .take()
140            .ok_or_else(|| ProviderError::RequestFailed("无法获取 app-server stdin".to_string()))?;
141
142        let stdout = child.stdout.take().ok_or_else(|| {
143            ProviderError::RequestFailed("无法获取 app-server stdout".to_string())
144        })?;
145
146        let stdout_reader = BufReader::new(stdout);
147
148        Ok(Self {
149            child,
150            stdin,
151            stdout_reader,
152            current_thread_id: None,
153            pending_responses: HashMap::new(),
154        })
155    }
156
157    /// 发送 JSON-RPC 请求
158    fn send_request(&mut self, method: &str, params: Value) -> Result<u64, ProviderError> {
159        let id = next_request_id();
160        let request = json!({
161            "method": method,
162            "id": id,
163            "params": params
164        });
165
166        let request_str = serde_json::to_string(&request)
167            .map_err(|e| ProviderError::RequestFailed(format!("序列化请求失败: {}", e)))?;
168
169        writeln!(self.stdin, "{}", request_str)
170            .map_err(|e| ProviderError::RequestFailed(format!("发送请求失败: {}", e)))?;
171
172        self.stdin
173            .flush()
174            .map_err(|e| ProviderError::RequestFailed(format!("刷新 stdin 失败: {}", e)))?;
175
176        tracing::debug!("发送请求: {} (id={})", method, id);
177        Ok(id)
178    }
179
180    /// 发送通知(无需响应)
181    fn send_notification(&mut self, method: &str, params: Value) -> Result<(), ProviderError> {
182        let notification = json!({
183            "method": method,
184            "params": params
185        });
186
187        let notification_str = serde_json::to_string(&notification)
188            .map_err(|e| ProviderError::RequestFailed(format!("序列化通知失败: {}", e)))?;
189
190        writeln!(self.stdin, "{}", notification_str)
191            .map_err(|e| ProviderError::RequestFailed(format!("发送通知失败: {}", e)))?;
192
193        self.stdin
194            .flush()
195            .map_err(|e| ProviderError::RequestFailed(format!("刷新 stdin 失败: {}", e)))?;
196
197        tracing::debug!("发送通知: {}", method);
198        Ok(())
199    }
200
201    /// 读取一行响应
202    fn read_line(&mut self) -> Result<String, ProviderError> {
203        let mut line = String::new();
204        self.stdout_reader
205            .read_line(&mut line)
206            .map_err(|e| ProviderError::RequestFailed(format!("读取响应失败: {}", e)))?;
207        Ok(line.trim().to_string())
208    }
209
210    /// 解析 JSON-RPC 消息
211    fn parse_message(&self, line: &str) -> Result<Value, ProviderError> {
212        serde_json::from_str(line).map_err(|e| {
213            ProviderError::RequestFailed(format!("解析 JSON 失败: {} (内容: {})", e, line))
214        })
215    }
216
217    /// 初始化连接
218    pub fn initialize(
219        &mut self,
220        client_name: &str,
221        client_version: &str,
222    ) -> Result<Value, ProviderError> {
223        let params = json!({
224            "clientInfo": {
225                "name": client_name,
226                "version": client_version
227            }
228        });
229
230        let id = self.send_request("initialize", params)?;
231
232        // 读取响应直到获得匹配的 result
233        loop {
234            let line = self.read_line()?;
235            if line.is_empty() {
236                continue;
237            }
238
239            let msg = self.parse_message(&line)?;
240
241            // 检查是否是我们的响应
242            if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
243                if msg_id == id {
244                    if let Some(error) = msg.get("error") {
245                        return Err(ProviderError::RequestFailed(format!(
246                            "initialize 失败: {}",
247                            error
248                        )));
249                    }
250                    let result = msg.get("result").cloned().unwrap_or(json!({}));
251
252                    // 发送 initialized 通知
253                    self.send_notification("initialized", json!({}))?;
254
255                    return Ok(result);
256                }
257            }
258        }
259    }
260
261    /// 启动新线程
262    pub fn thread_start(
263        &mut self,
264        model: Option<&str>,
265        cwd: Option<&str>,
266        approval_policy: Option<&str>,
267        sandbox: Option<&str>,
268    ) -> Result<ThreadInfo, ProviderError> {
269        let mut params = json!({});
270
271        if let Some(m) = model {
272            params["model"] = json!(m);
273        }
274        if let Some(dir) = cwd {
275            params["cwd"] = json!(dir);
276        }
277        if let Some(policy) = approval_policy {
278            params["approvalPolicy"] = json!(policy);
279        }
280        if let Some(sb) = sandbox {
281            params["sandbox"] = json!(sb);
282        }
283
284        let id = self.send_request("thread/start", params)?;
285
286        // 读取响应
287        loop {
288            let line = self.read_line()?;
289            if line.is_empty() {
290                continue;
291            }
292
293            let msg = self.parse_message(&line)?;
294
295            // 检查是否是我们的响应
296            if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
297                if msg_id == id {
298                    if let Some(error) = msg.get("error") {
299                        return Err(ProviderError::RequestFailed(format!(
300                            "thread/start 失败: {}",
301                            error
302                        )));
303                    }
304
305                    let thread: ThreadInfo = serde_json::from_value(
306                        msg.get("result")
307                            .and_then(|r| r.get("thread"))
308                            .cloned()
309                            .unwrap_or(json!({})),
310                    )
311                    .map_err(|e| {
312                        ProviderError::RequestFailed(format!("解析 thread 失败: {}", e))
313                    })?;
314
315                    self.current_thread_id = Some(thread.id.clone());
316                    return Ok(thread);
317                }
318            }
319
320            // 处理 thread/started 通知
321            if msg.get("method").and_then(|v| v.as_str()) == Some("thread/started") {
322                tracing::debug!("收到 thread/started 通知");
323            }
324        }
325    }
326
327    /// 恢复已有线程
328    pub fn thread_resume(&mut self, thread_id: &str) -> Result<(), ProviderError> {
329        let params = json!({
330            "thread_id": thread_id
331        });
332
333        let id = self.send_request("thread/resume", params)?;
334
335        // 读取响应
336        loop {
337            let line = self.read_line()?;
338            if line.is_empty() {
339                continue;
340            }
341
342            let msg = self.parse_message(&line)?;
343
344            if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
345                if msg_id == id {
346                    if let Some(error) = msg.get("error") {
347                        return Err(ProviderError::RequestFailed(format!(
348                            "thread/resume 失败: {}",
349                            error
350                        )));
351                    }
352
353                    self.current_thread_id = Some(thread_id.to_string());
354                    return Ok(());
355                }
356            }
357        }
358    }
359
360    /// 获取当前 thread ID
361    pub fn current_thread_id(&self) -> Option<&str> {
362        self.current_thread_id.as_deref()
363    }
364
365    /// 启动一个 turn 并收集所有事件
366    pub fn turn_start(
367        &mut self,
368        input_text: &str,
369        model: Option<&str>,
370        effort: Option<&str>,
371    ) -> Result<(String, Vec<AppServerEvent>), ProviderError> {
372        let thread_id = self.current_thread_id.clone().ok_or_else(|| {
373            ProviderError::RequestFailed("没有活动的 thread,请先调用 thread_start".to_string())
374        })?;
375
376        let mut params = json!({
377            "threadId": thread_id,
378            "input": [
379                { "type": "text", "text": input_text }
380            ]
381        });
382
383        if let Some(m) = model {
384            params["model"] = json!(m);
385        }
386        if let Some(e) = effort {
387            params["effort"] = json!(e);
388        }
389
390        let id = self.send_request("turn/start", params)?;
391
392        let mut events = Vec::new();
393        let mut accumulated_text = String::new();
394        let mut turn_completed = false;
395
396        // 读取事件流直到 turn 完成
397        while !turn_completed {
398            let line = self.read_line()?;
399            if line.is_empty() {
400                continue;
401            }
402
403            let msg = self.parse_message(&line)?;
404
405            // 检查是否是 turn/start 的响应
406            if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
407                if msg_id == id {
408                    if let Some(error) = msg.get("error") {
409                        return Err(ProviderError::RequestFailed(format!(
410                            "turn/start 失败: {}",
411                            error
412                        )));
413                    }
414                    // turn/start 响应只是确认,继续读取事件
415                    continue;
416                }
417            }
418
419            // 处理通知事件
420            if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
421                let params = msg.get("params").cloned().unwrap_or(json!({}));
422                let event = self.parse_event(method, &params, &mut accumulated_text);
423
424                match &event {
425                    AppServerEvent::TurnCompleted(_) => {
426                        turn_completed = true;
427                    }
428                    AppServerEvent::Error(e) => {
429                        tracing::error!("收到错误事件: {}", e);
430                    }
431                    _ => {}
432                }
433
434                events.push(event);
435            }
436        }
437
438        Ok((accumulated_text, events))
439    }
440
441    /// 解析事件
442    fn parse_event(
443        &self,
444        method: &str,
445        params: &Value,
446        accumulated_text: &mut String,
447    ) -> AppServerEvent {
448        match method {
449            "thread/started" => {
450                let thread: ThreadInfo =
451                    serde_json::from_value(params.get("thread").cloned().unwrap_or(json!({})))
452                        .unwrap_or(ThreadInfo {
453                            id: "unknown".to_string(),
454                            preview: None,
455                            model_provider: None,
456                            created_at: None,
457                        });
458                AppServerEvent::ThreadStarted(thread)
459            }
460
461            "turn/started" => {
462                let turn: TurnInfo =
463                    serde_json::from_value(params.get("turn").cloned().unwrap_or(json!({})))
464                        .unwrap_or(TurnInfo {
465                            id: "unknown".to_string(),
466                            status: "unknown".to_string(),
467                            items: vec![],
468                            error: None,
469                        });
470                AppServerEvent::TurnStarted(turn)
471            }
472
473            "item/started" => {
474                let item_id = params
475                    .get("item")
476                    .and_then(|i| i.get("id"))
477                    .and_then(|v| v.as_str())
478                    .unwrap_or("unknown")
479                    .to_string();
480                let item_type = params
481                    .get("item")
482                    .and_then(|i| i.get("type"))
483                    .and_then(|v| v.as_str())
484                    .unwrap_or("unknown")
485                    .to_string();
486                AppServerEvent::ItemStarted { item_id, item_type }
487            }
488
489            "item/agentMessage/delta" => {
490                let item_id = params
491                    .get("itemId")
492                    .and_then(|v| v.as_str())
493                    .unwrap_or("unknown")
494                    .to_string();
495                let text = params
496                    .get("delta")
497                    .and_then(|v| v.as_str())
498                    .unwrap_or("")
499                    .to_string();
500
501                // 累积文本
502                accumulated_text.push_str(&text);
503
504                AppServerEvent::AgentMessageDelta { item_id, text }
505            }
506
507            "item/reasoning/delta" => {
508                let item_id = params
509                    .get("itemId")
510                    .and_then(|v| v.as_str())
511                    .unwrap_or("unknown")
512                    .to_string();
513                let text = params
514                    .get("delta")
515                    .and_then(|v| v.as_str())
516                    .unwrap_or("")
517                    .to_string();
518                AppServerEvent::ReasoningDelta { item_id, text }
519            }
520
521            "item/completed" => {
522                let item_id = params
523                    .get("item")
524                    .and_then(|i| i.get("id"))
525                    .and_then(|v| v.as_str())
526                    .unwrap_or("unknown")
527                    .to_string();
528                AppServerEvent::ItemCompleted { item_id }
529            }
530
531            "turn/completed" => {
532                let turn: TurnInfo =
533                    serde_json::from_value(params.get("turn").cloned().unwrap_or(json!({})))
534                        .unwrap_or(TurnInfo {
535                            id: "unknown".to_string(),
536                            status: "completed".to_string(),
537                            items: vec![],
538                            error: None,
539                        });
540                AppServerEvent::TurnCompleted(turn)
541            }
542
543            "error" => {
544                let message = params
545                    .get("message")
546                    .and_then(|v| v.as_str())
547                    .unwrap_or("未知错误")
548                    .to_string();
549                AppServerEvent::Error(message)
550            }
551
552            _ => AppServerEvent::Unknown(params.clone()),
553        }
554    }
555
556    /// 中断当前 turn
557    pub fn turn_interrupt(&mut self) -> Result<(), ProviderError> {
558        let thread_id = self
559            .current_thread_id
560            .clone()
561            .ok_or_else(|| ProviderError::RequestFailed("没有活动的 thread".to_string()))?;
562
563        let params = json!({
564            "threadId": thread_id
565        });
566
567        self.send_notification("turn/interrupt", params)?;
568        Ok(())
569    }
570
571    /// 关闭连接
572    pub fn close(&mut self) -> Result<(), ProviderError> {
573        // 尝试优雅关闭
574        let _ = self.child.kill();
575        let _ = self.child.wait();
576        Ok(())
577    }
578
579    /// 检查进程是否还在运行
580    pub fn is_alive(&mut self) -> bool {
581        match self.child.try_wait() {
582            Ok(Some(_)) => false, // 进程已退出
583            Ok(None) => true,     // 进程仍在运行
584            Err(_) => false,      // 出错,假设已退出
585        }
586    }
587}
588
589impl Drop for CodexAppServerConnection {
590    fn drop(&mut self) {
591        let _ = self.close();
592    }
593}
594
/// Session manager: keeps one Codex app-server connection per conversation.
pub struct CodexSessionManager {
    /// Path of the Codex executable used to spawn app-server processes.
    command: PathBuf,
    /// Live connections (conversation_id -> connection).
    connections: Arc<Mutex<HashMap<String, CodexAppServerConnection>>>,
    /// Thread ids per conversation (conversation_id -> thread_id), consulted to
    /// resume a thread when a connection has to be re-created.
    session_map: Arc<Mutex<HashMap<String, String>>>,
}

605impl CodexSessionManager {
606    /// 创建新的会话管理器
607    pub fn new(command: PathBuf) -> Self {
608        Self {
609            command,
610            connections: Arc::new(Mutex::new(HashMap::new())),
611            session_map: Arc::new(Mutex::new(HashMap::new())),
612        }
613    }
614
615    /// 获取或创建连接
616    pub fn get_or_create_connection(
617        &self,
618        conversation_id: &str,
619        cwd: Option<&str>,
620        model: Option<&str>,
621    ) -> Result<(), ProviderError> {
622        let mut connections = self
623            .connections
624            .lock()
625            .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
626
627        // 检查是否已有连接
628        if let Some(conn) = connections.get_mut(conversation_id) {
629            if conn.is_alive() {
630                return Ok(());
631            }
632            // 连接已死,移除
633            connections.remove(conversation_id);
634        }
635
636        // 创建新连接
637        let mut conn = CodexAppServerConnection::spawn(&self.command, cwd)?;
638
639        // 初始化
640        conn.initialize("aster", env!("CARGO_PKG_VERSION"))?;
641
642        // 检查是否有已保存的 thread_id
643        let session_map = self
644            .session_map
645            .lock()
646            .map_err(|e| ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e)))?;
647
648        if let Some(thread_id) = session_map.get(conversation_id) {
649            // 尝试恢复会话
650            match conn.thread_resume(thread_id) {
651                Ok(_) => {
652                    tracing::info!("恢复会话成功: {} -> {}", conversation_id, thread_id);
653                }
654                Err(e) => {
655                    tracing::warn!("恢复会话失败,创建新会话: {}", e);
656                    drop(session_map);
657                    let thread =
658                        conn.thread_start(model, cwd, Some("never"), Some("workspaceWrite"))?;
659                    let mut session_map = self.session_map.lock().map_err(|e| {
660                        ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e))
661                    })?;
662                    session_map.insert(conversation_id.to_string(), thread.id);
663                }
664            }
665        } else {
666            drop(session_map);
667            // 创建新会话
668            let thread = conn.thread_start(model, cwd, Some("never"), Some("workspaceWrite"))?;
669            let mut session_map = self
670                .session_map
671                .lock()
672                .map_err(|e| ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e)))?;
673            session_map.insert(conversation_id.to_string(), thread.id);
674            tracing::info!(
675                "创建新会话: {} -> {}",
676                conversation_id,
677                session_map.get(conversation_id).unwrap()
678            );
679        }
680
681        connections.insert(conversation_id.to_string(), conn);
682        Ok(())
683    }
684
685    /// 发送消息并获取响应
686    pub fn send_message(
687        &self,
688        conversation_id: &str,
689        message: &str,
690        model: Option<&str>,
691        effort: Option<&str>,
692    ) -> Result<(String, Vec<AppServerEvent>), ProviderError> {
693        let mut connections = self
694            .connections
695            .lock()
696            .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
697
698        let conn = connections.get_mut(conversation_id).ok_or_else(|| {
699            ProviderError::RequestFailed(format!("会话不存在: {}", conversation_id))
700        })?;
701
702        conn.turn_start(message, model, effort)
703    }
704
705    /// 获取会话的 thread_id
706    pub fn get_thread_id(&self, conversation_id: &str) -> Option<String> {
707        self.session_map
708            .lock()
709            .ok()
710            .and_then(|map| map.get(conversation_id).cloned())
711    }
712
713    /// 关闭会话
714    pub fn close_session(&self, conversation_id: &str) -> Result<(), ProviderError> {
715        let mut connections = self
716            .connections
717            .lock()
718            .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
719
720        if let Some(mut conn) = connections.remove(conversation_id) {
721            conn.close()?;
722        }
723
724        Ok(())
725    }
726
727    /// 关闭所有会话
728    pub fn close_all(&self) -> Result<(), ProviderError> {
729        let mut connections = self
730            .connections
731            .lock()
732            .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
733
734        for (_, mut conn) in connections.drain() {
735            let _ = conn.close();
736        }
737
738        Ok(())
739    }
740}
741
742impl Drop for CodexSessionManager {
743    fn drop(&mut self) {
744        let _ = self.close_all();
745    }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    /// Two back-to-back ids must be strictly increasing.
    #[test]
    fn test_request_id_generation() {
        let first = next_request_id();
        let second = next_request_id();
        assert!(first < second);
    }

    /// camelCase wire keys map onto the snake_case `ThreadInfo` fields.
    #[test]
    fn test_thread_info_deserialize() {
        let raw = r#"{
            "id": "thr_123",
            "preview": "Test thread",
            "modelProvider": "openai",
            "createdAt": 1730910000
        }"#;

        let parsed = serde_json::from_str::<ThreadInfo>(raw).unwrap();
        assert_eq!(parsed.id, "thr_123");
        assert_eq!(parsed.preview.as_deref(), Some("Test thread"));
        assert_eq!(parsed.model_provider.as_deref(), Some("openai"));
    }

    /// An in-progress turn with no items and a null error deserializes cleanly.
    #[test]
    fn test_turn_info_deserialize() {
        let raw = r#"{
            "id": "turn_456",
            "status": "inProgress",
            "items": [],
            "error": null
        }"#;

        let parsed = serde_json::from_str::<TurnInfo>(raw).unwrap();
        assert_eq!(parsed.id, "turn_456");
        assert_eq!(parsed.status, "inProgress");
        assert_eq!(parsed.items.len(), 0);
        assert_eq!(parsed.error, None);
    }
}