1use serde::{Deserialize, Serialize};
14use serde_json::{json, Value};
15use std::collections::HashMap;
16use std::io::{BufRead, BufReader, Write};
17use std::path::PathBuf;
18use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::{Arc, Mutex};
21
22use super::errors::ProviderError;
23
/// Process-wide source of JSON-RPC request ids. Starts at 1 and is only ever incremented.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);

/// Hands out the next request id; successive calls observe distinct, increasing values
/// (until the `u64` counter wraps).
fn next_request_id() -> u64 {
    AtomicU64::fetch_add(&REQUEST_ID, 1, Ordering::SeqCst)
}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ThreadInfo {
34 pub id: String,
35 pub preview: Option<String>,
36 #[serde(rename = "modelProvider")]
37 pub model_provider: Option<String>,
38 #[serde(rename = "createdAt")]
39 pub created_at: Option<i64>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TurnInfo {
45 pub id: String,
46 pub status: String,
47 pub items: Vec<TurnItem>,
48 pub error: Option<String>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type")]
54pub enum TurnItem {
55 #[serde(rename = "agentMessage")]
56 AgentMessage {
57 id: String,
58 text: Option<String>,
59 #[serde(default)]
60 complete: bool,
61 },
62 #[serde(rename = "reasoning")]
63 Reasoning {
64 id: String,
65 text: Option<String>,
66 #[serde(default)]
67 complete: bool,
68 },
69 #[serde(rename = "toolCall")]
70 ToolCall {
71 id: String,
72 name: Option<String>,
73 #[serde(default)]
74 complete: bool,
75 },
76 #[serde(other)]
77 Unknown,
78}
79
/// A notification received from the app-server while `turn_start` is reading a turn,
/// decoded by its `method` name in `parse_event`.
#[derive(Debug, Clone)]
pub enum AppServerEvent {
    /// `thread/started`: the thread described by the `thread` payload.
    ThreadStarted(ThreadInfo),
    /// `turn/started`: the turn described by the `turn` payload.
    TurnStarted(TurnInfo),
    /// `item/started`: the new item's `id` and `type`.
    ItemStarted { item_id: String, item_type: String },
    /// `item/agentMessage/delta`: an incremental chunk of assistant text.
    AgentMessageDelta { item_id: String, text: String },
    /// `item/reasoning/delta`: an incremental chunk of reasoning text.
    ReasoningDelta { item_id: String, text: String },
    /// `item/completed`: the finished item's `id`.
    ItemCompleted { item_id: String },
    /// `turn/completed`: ends the read loop in `turn_start`.
    TurnCompleted(TurnInfo),
    /// `error`: the `message` text from the payload.
    Error(String),
    /// Any other method; holds the raw `params` value.
    Unknown(Value),
}
102
/// A connection to one `codex app-server` child process. Requests and notifications are
/// exchanged as one JSON message per line over the child's stdin/stdout.
pub struct CodexAppServerConnection {
    /// The spawned app-server process.
    child: Child,
    /// Pipe used to write requests and notifications to the child.
    stdin: ChildStdin,
    /// Buffered reader over the child's stdout; each line is parsed as one JSON message.
    stdout_reader: BufReader<ChildStdout>,
    /// Thread id recorded by a successful `thread/start` or `thread/resume`; `turn/start`
    /// and `turn/interrupt` require it.
    current_thread_id: Option<String>,
    /// Presumably meant to map request ids to waiters for their responses.
    /// NOTE(review): only initialised in `spawn` and never read elsewhere in this file —
    /// appears unused; confirm before relying on it.
    pending_responses: HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, ProviderError>>>,
}
116
117impl CodexAppServerConnection {
118 pub fn spawn(command: &PathBuf, cwd: Option<&str>) -> Result<Self, ProviderError> {
120 let mut cmd = Command::new(command);
121 cmd.arg("app-server")
122 .stdin(Stdio::piped())
123 .stdout(Stdio::piped())
124 .stderr(Stdio::piped());
125
126 if let Some(dir) = cwd {
127 cmd.current_dir(dir);
128 }
129
130 let mut child = cmd.spawn().map_err(|e| {
131 ProviderError::RequestFailed(format!(
132 "无法启动 Codex app-server: {}. 请确保已安装 Codex CLI (npm i -g @openai/codex)",
133 e
134 ))
135 })?;
136
137 let stdin = child
138 .stdin
139 .take()
140 .ok_or_else(|| ProviderError::RequestFailed("无法获取 app-server stdin".to_string()))?;
141
142 let stdout = child.stdout.take().ok_or_else(|| {
143 ProviderError::RequestFailed("无法获取 app-server stdout".to_string())
144 })?;
145
146 let stdout_reader = BufReader::new(stdout);
147
148 Ok(Self {
149 child,
150 stdin,
151 stdout_reader,
152 current_thread_id: None,
153 pending_responses: HashMap::new(),
154 })
155 }
156
157 fn send_request(&mut self, method: &str, params: Value) -> Result<u64, ProviderError> {
159 let id = next_request_id();
160 let request = json!({
161 "method": method,
162 "id": id,
163 "params": params
164 });
165
166 let request_str = serde_json::to_string(&request)
167 .map_err(|e| ProviderError::RequestFailed(format!("序列化请求失败: {}", e)))?;
168
169 writeln!(self.stdin, "{}", request_str)
170 .map_err(|e| ProviderError::RequestFailed(format!("发送请求失败: {}", e)))?;
171
172 self.stdin
173 .flush()
174 .map_err(|e| ProviderError::RequestFailed(format!("刷新 stdin 失败: {}", e)))?;
175
176 tracing::debug!("发送请求: {} (id={})", method, id);
177 Ok(id)
178 }
179
180 fn send_notification(&mut self, method: &str, params: Value) -> Result<(), ProviderError> {
182 let notification = json!({
183 "method": method,
184 "params": params
185 });
186
187 let notification_str = serde_json::to_string(¬ification)
188 .map_err(|e| ProviderError::RequestFailed(format!("序列化通知失败: {}", e)))?;
189
190 writeln!(self.stdin, "{}", notification_str)
191 .map_err(|e| ProviderError::RequestFailed(format!("发送通知失败: {}", e)))?;
192
193 self.stdin
194 .flush()
195 .map_err(|e| ProviderError::RequestFailed(format!("刷新 stdin 失败: {}", e)))?;
196
197 tracing::debug!("发送通知: {}", method);
198 Ok(())
199 }
200
201 fn read_line(&mut self) -> Result<String, ProviderError> {
203 let mut line = String::new();
204 self.stdout_reader
205 .read_line(&mut line)
206 .map_err(|e| ProviderError::RequestFailed(format!("读取响应失败: {}", e)))?;
207 Ok(line.trim().to_string())
208 }
209
210 fn parse_message(&self, line: &str) -> Result<Value, ProviderError> {
212 serde_json::from_str(line).map_err(|e| {
213 ProviderError::RequestFailed(format!("解析 JSON 失败: {} (内容: {})", e, line))
214 })
215 }
216
217 pub fn initialize(
219 &mut self,
220 client_name: &str,
221 client_version: &str,
222 ) -> Result<Value, ProviderError> {
223 let params = json!({
224 "clientInfo": {
225 "name": client_name,
226 "version": client_version
227 }
228 });
229
230 let id = self.send_request("initialize", params)?;
231
232 loop {
234 let line = self.read_line()?;
235 if line.is_empty() {
236 continue;
237 }
238
239 let msg = self.parse_message(&line)?;
240
241 if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
243 if msg_id == id {
244 if let Some(error) = msg.get("error") {
245 return Err(ProviderError::RequestFailed(format!(
246 "initialize 失败: {}",
247 error
248 )));
249 }
250 let result = msg.get("result").cloned().unwrap_or(json!({}));
251
252 self.send_notification("initialized", json!({}))?;
254
255 return Ok(result);
256 }
257 }
258 }
259 }
260
261 pub fn thread_start(
263 &mut self,
264 model: Option<&str>,
265 cwd: Option<&str>,
266 approval_policy: Option<&str>,
267 sandbox: Option<&str>,
268 ) -> Result<ThreadInfo, ProviderError> {
269 let mut params = json!({});
270
271 if let Some(m) = model {
272 params["model"] = json!(m);
273 }
274 if let Some(dir) = cwd {
275 params["cwd"] = json!(dir);
276 }
277 if let Some(policy) = approval_policy {
278 params["approvalPolicy"] = json!(policy);
279 }
280 if let Some(sb) = sandbox {
281 params["sandbox"] = json!(sb);
282 }
283
284 let id = self.send_request("thread/start", params)?;
285
286 loop {
288 let line = self.read_line()?;
289 if line.is_empty() {
290 continue;
291 }
292
293 let msg = self.parse_message(&line)?;
294
295 if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
297 if msg_id == id {
298 if let Some(error) = msg.get("error") {
299 return Err(ProviderError::RequestFailed(format!(
300 "thread/start 失败: {}",
301 error
302 )));
303 }
304
305 let thread: ThreadInfo = serde_json::from_value(
306 msg.get("result")
307 .and_then(|r| r.get("thread"))
308 .cloned()
309 .unwrap_or(json!({})),
310 )
311 .map_err(|e| {
312 ProviderError::RequestFailed(format!("解析 thread 失败: {}", e))
313 })?;
314
315 self.current_thread_id = Some(thread.id.clone());
316 return Ok(thread);
317 }
318 }
319
320 if msg.get("method").and_then(|v| v.as_str()) == Some("thread/started") {
322 tracing::debug!("收到 thread/started 通知");
323 }
324 }
325 }
326
327 pub fn thread_resume(&mut self, thread_id: &str) -> Result<(), ProviderError> {
329 let params = json!({
330 "thread_id": thread_id
331 });
332
333 let id = self.send_request("thread/resume", params)?;
334
335 loop {
337 let line = self.read_line()?;
338 if line.is_empty() {
339 continue;
340 }
341
342 let msg = self.parse_message(&line)?;
343
344 if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
345 if msg_id == id {
346 if let Some(error) = msg.get("error") {
347 return Err(ProviderError::RequestFailed(format!(
348 "thread/resume 失败: {}",
349 error
350 )));
351 }
352
353 self.current_thread_id = Some(thread_id.to_string());
354 return Ok(());
355 }
356 }
357 }
358 }
359
360 pub fn current_thread_id(&self) -> Option<&str> {
362 self.current_thread_id.as_deref()
363 }
364
365 pub fn turn_start(
367 &mut self,
368 input_text: &str,
369 model: Option<&str>,
370 effort: Option<&str>,
371 ) -> Result<(String, Vec<AppServerEvent>), ProviderError> {
372 let thread_id = self.current_thread_id.clone().ok_or_else(|| {
373 ProviderError::RequestFailed("没有活动的 thread,请先调用 thread_start".to_string())
374 })?;
375
376 let mut params = json!({
377 "threadId": thread_id,
378 "input": [
379 { "type": "text", "text": input_text }
380 ]
381 });
382
383 if let Some(m) = model {
384 params["model"] = json!(m);
385 }
386 if let Some(e) = effort {
387 params["effort"] = json!(e);
388 }
389
390 let id = self.send_request("turn/start", params)?;
391
392 let mut events = Vec::new();
393 let mut accumulated_text = String::new();
394 let mut turn_completed = false;
395
396 while !turn_completed {
398 let line = self.read_line()?;
399 if line.is_empty() {
400 continue;
401 }
402
403 let msg = self.parse_message(&line)?;
404
405 if let Some(msg_id) = msg.get("id").and_then(|v| v.as_u64()) {
407 if msg_id == id {
408 if let Some(error) = msg.get("error") {
409 return Err(ProviderError::RequestFailed(format!(
410 "turn/start 失败: {}",
411 error
412 )));
413 }
414 continue;
416 }
417 }
418
419 if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
421 let params = msg.get("params").cloned().unwrap_or(json!({}));
422 let event = self.parse_event(method, ¶ms, &mut accumulated_text);
423
424 match &event {
425 AppServerEvent::TurnCompleted(_) => {
426 turn_completed = true;
427 }
428 AppServerEvent::Error(e) => {
429 tracing::error!("收到错误事件: {}", e);
430 }
431 _ => {}
432 }
433
434 events.push(event);
435 }
436 }
437
438 Ok((accumulated_text, events))
439 }
440
441 fn parse_event(
443 &self,
444 method: &str,
445 params: &Value,
446 accumulated_text: &mut String,
447 ) -> AppServerEvent {
448 match method {
449 "thread/started" => {
450 let thread: ThreadInfo =
451 serde_json::from_value(params.get("thread").cloned().unwrap_or(json!({})))
452 .unwrap_or(ThreadInfo {
453 id: "unknown".to_string(),
454 preview: None,
455 model_provider: None,
456 created_at: None,
457 });
458 AppServerEvent::ThreadStarted(thread)
459 }
460
461 "turn/started" => {
462 let turn: TurnInfo =
463 serde_json::from_value(params.get("turn").cloned().unwrap_or(json!({})))
464 .unwrap_or(TurnInfo {
465 id: "unknown".to_string(),
466 status: "unknown".to_string(),
467 items: vec![],
468 error: None,
469 });
470 AppServerEvent::TurnStarted(turn)
471 }
472
473 "item/started" => {
474 let item_id = params
475 .get("item")
476 .and_then(|i| i.get("id"))
477 .and_then(|v| v.as_str())
478 .unwrap_or("unknown")
479 .to_string();
480 let item_type = params
481 .get("item")
482 .and_then(|i| i.get("type"))
483 .and_then(|v| v.as_str())
484 .unwrap_or("unknown")
485 .to_string();
486 AppServerEvent::ItemStarted { item_id, item_type }
487 }
488
489 "item/agentMessage/delta" => {
490 let item_id = params
491 .get("itemId")
492 .and_then(|v| v.as_str())
493 .unwrap_or("unknown")
494 .to_string();
495 let text = params
496 .get("delta")
497 .and_then(|v| v.as_str())
498 .unwrap_or("")
499 .to_string();
500
501 accumulated_text.push_str(&text);
503
504 AppServerEvent::AgentMessageDelta { item_id, text }
505 }
506
507 "item/reasoning/delta" => {
508 let item_id = params
509 .get("itemId")
510 .and_then(|v| v.as_str())
511 .unwrap_or("unknown")
512 .to_string();
513 let text = params
514 .get("delta")
515 .and_then(|v| v.as_str())
516 .unwrap_or("")
517 .to_string();
518 AppServerEvent::ReasoningDelta { item_id, text }
519 }
520
521 "item/completed" => {
522 let item_id = params
523 .get("item")
524 .and_then(|i| i.get("id"))
525 .and_then(|v| v.as_str())
526 .unwrap_or("unknown")
527 .to_string();
528 AppServerEvent::ItemCompleted { item_id }
529 }
530
531 "turn/completed" => {
532 let turn: TurnInfo =
533 serde_json::from_value(params.get("turn").cloned().unwrap_or(json!({})))
534 .unwrap_or(TurnInfo {
535 id: "unknown".to_string(),
536 status: "completed".to_string(),
537 items: vec![],
538 error: None,
539 });
540 AppServerEvent::TurnCompleted(turn)
541 }
542
543 "error" => {
544 let message = params
545 .get("message")
546 .and_then(|v| v.as_str())
547 .unwrap_or("未知错误")
548 .to_string();
549 AppServerEvent::Error(message)
550 }
551
552 _ => AppServerEvent::Unknown(params.clone()),
553 }
554 }
555
556 pub fn turn_interrupt(&mut self) -> Result<(), ProviderError> {
558 let thread_id = self
559 .current_thread_id
560 .clone()
561 .ok_or_else(|| ProviderError::RequestFailed("没有活动的 thread".to_string()))?;
562
563 let params = json!({
564 "threadId": thread_id
565 });
566
567 self.send_notification("turn/interrupt", params)?;
568 Ok(())
569 }
570
571 pub fn close(&mut self) -> Result<(), ProviderError> {
573 let _ = self.child.kill();
575 let _ = self.child.wait();
576 Ok(())
577 }
578
579 pub fn is_alive(&mut self) -> bool {
581 match self.child.try_wait() {
582 Ok(Some(_)) => false, Ok(None) => true, Err(_) => false, }
586 }
587}
588
589impl Drop for CodexAppServerConnection {
590 fn drop(&mut self) {
591 let _ = self.close();
592 }
593}
594
/// Owns one app-server connection per conversation. It also remembers each conversation's
/// Codex thread id, so a replacement connection can try to resume that thread.
pub struct CodexSessionManager {
    /// Executable launched as `<command> app-server` for each new connection.
    command: PathBuf,
    /// Connections keyed by conversation id.
    connections: Arc<Mutex<HashMap<String, CodexAppServerConnection>>>,
    /// Conversation id -> Codex thread id.
    session_map: Arc<Mutex<HashMap<String, String>>>,
}
604
605impl CodexSessionManager {
606 pub fn new(command: PathBuf) -> Self {
608 Self {
609 command,
610 connections: Arc::new(Mutex::new(HashMap::new())),
611 session_map: Arc::new(Mutex::new(HashMap::new())),
612 }
613 }
614
615 pub fn get_or_create_connection(
617 &self,
618 conversation_id: &str,
619 cwd: Option<&str>,
620 model: Option<&str>,
621 ) -> Result<(), ProviderError> {
622 let mut connections = self
623 .connections
624 .lock()
625 .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
626
627 if let Some(conn) = connections.get_mut(conversation_id) {
629 if conn.is_alive() {
630 return Ok(());
631 }
632 connections.remove(conversation_id);
634 }
635
636 let mut conn = CodexAppServerConnection::spawn(&self.command, cwd)?;
638
639 conn.initialize("aster", env!("CARGO_PKG_VERSION"))?;
641
642 let session_map = self
644 .session_map
645 .lock()
646 .map_err(|e| ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e)))?;
647
648 if let Some(thread_id) = session_map.get(conversation_id) {
649 match conn.thread_resume(thread_id) {
651 Ok(_) => {
652 tracing::info!("恢复会话成功: {} -> {}", conversation_id, thread_id);
653 }
654 Err(e) => {
655 tracing::warn!("恢复会话失败,创建新会话: {}", e);
656 drop(session_map);
657 let thread =
658 conn.thread_start(model, cwd, Some("never"), Some("workspaceWrite"))?;
659 let mut session_map = self.session_map.lock().map_err(|e| {
660 ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e))
661 })?;
662 session_map.insert(conversation_id.to_string(), thread.id);
663 }
664 }
665 } else {
666 drop(session_map);
667 let thread = conn.thread_start(model, cwd, Some("never"), Some("workspaceWrite"))?;
669 let mut session_map = self
670 .session_map
671 .lock()
672 .map_err(|e| ProviderError::RequestFailed(format!("获取会话映射锁失败: {}", e)))?;
673 session_map.insert(conversation_id.to_string(), thread.id);
674 tracing::info!(
675 "创建新会话: {} -> {}",
676 conversation_id,
677 session_map.get(conversation_id).unwrap()
678 );
679 }
680
681 connections.insert(conversation_id.to_string(), conn);
682 Ok(())
683 }
684
685 pub fn send_message(
687 &self,
688 conversation_id: &str,
689 message: &str,
690 model: Option<&str>,
691 effort: Option<&str>,
692 ) -> Result<(String, Vec<AppServerEvent>), ProviderError> {
693 let mut connections = self
694 .connections
695 .lock()
696 .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
697
698 let conn = connections.get_mut(conversation_id).ok_or_else(|| {
699 ProviderError::RequestFailed(format!("会话不存在: {}", conversation_id))
700 })?;
701
702 conn.turn_start(message, model, effort)
703 }
704
705 pub fn get_thread_id(&self, conversation_id: &str) -> Option<String> {
707 self.session_map
708 .lock()
709 .ok()
710 .and_then(|map| map.get(conversation_id).cloned())
711 }
712
713 pub fn close_session(&self, conversation_id: &str) -> Result<(), ProviderError> {
715 let mut connections = self
716 .connections
717 .lock()
718 .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
719
720 if let Some(mut conn) = connections.remove(conversation_id) {
721 conn.close()?;
722 }
723
724 Ok(())
725 }
726
727 pub fn close_all(&self) -> Result<(), ProviderError> {
729 let mut connections = self
730 .connections
731 .lock()
732 .map_err(|e| ProviderError::RequestFailed(format!("获取连接锁失败: {}", e)))?;
733
734 for (_, mut conn) in connections.drain() {
735 let _ = conn.close();
736 }
737
738 Ok(())
739 }
740}
741
742impl Drop for CodexSessionManager {
743 fn drop(&mut self) {
744 let _ = self.close_all();
745 }
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_id_generation() {
        let id1 = next_request_id();
        let id2 = next_request_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_thread_info_deserialize() {
        let json = r#"{
            "id": "thr_123",
            "preview": "Test thread",
            "modelProvider": "openai",
            "createdAt": 1730910000
        }"#;

        let thread: ThreadInfo = serde_json::from_str(json).unwrap();
        assert_eq!(thread.id, "thr_123");
        assert_eq!(thread.preview, Some("Test thread".to_string()));
        assert_eq!(thread.model_provider, Some("openai".to_string()));
        assert_eq!(thread.created_at, Some(1730910000));
    }

    #[test]
    fn test_thread_info_optional_fields_default_to_none() {
        let thread: ThreadInfo = serde_json::from_str(r#"{"id": "thr_1"}"#).unwrap();
        assert_eq!(thread.id, "thr_1");
        assert!(thread.preview.is_none());
        assert!(thread.model_provider.is_none());
        assert!(thread.created_at.is_none());
    }

    #[test]
    fn test_turn_info_deserialize() {
        let json = r#"{
            "id": "turn_456",
            "status": "inProgress",
            "items": [],
            "error": null
        }"#;

        let turn: TurnInfo = serde_json::from_str(json).unwrap();
        assert_eq!(turn.id, "turn_456");
        assert_eq!(turn.status, "inProgress");
        assert!(turn.items.is_empty());
        assert!(turn.error.is_none());
    }

    #[test]
    fn test_turn_item_agent_message_defaults_complete_to_false() {
        let json = r#"{"type": "agentMessage", "id": "item_1", "text": "hi"}"#;
        let item: TurnItem = serde_json::from_str(json).unwrap();
        match item {
            TurnItem::AgentMessage { id, text, complete } => {
                assert_eq!(id, "item_1");
                assert_eq!(text.as_deref(), Some("hi"));
                assert!(!complete);
            }
            other => panic!("unexpected item: {:?}", other),
        }
    }

    #[test]
    fn test_turn_item_unknown_type_falls_back_to_unknown() {
        let json = r#"{"type": "commandExecution", "id": "item_2"}"#;
        let item: TurnItem = serde_json::from_str(json).unwrap();
        assert!(matches!(item, TurnItem::Unknown));
    }
}