1use std::collections::HashMap;
14#[cfg(windows)]
15use std::os::windows::process::CommandExt;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
21use tokio::process::{Child, ChildStdin};
22use tokio::sync::{broadcast, oneshot, Mutex};
23
24use super::terminal_manager::TerminalManager;
25#[cfg(windows)]
26use super::CREATE_NO_WINDOW;
27use crate::trace::{
28 Contributor, TraceConversation, TraceEventType, TraceRecord, TraceTool, TraceWriter,
29};
30
/// Broadcast channel on which raw JSON-RPC notifications are published: the agent's
/// (session-id-rewritten) notifications and synthetic `process_output` messages built
/// from its stderr.
pub type NotificationSender = broadcast::Sender<serde_json::Value>;

/// Requests awaiting a response, keyed by JSON-RPC request id. Each sender receives
/// either the response's `result` value or an error message.
type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<serde_json::Value, String>>>>>;
36
/// A spawned agent process spoken to with newline-delimited JSON-RPC 2.0 over stdio.
pub struct AcpProcess {
    /// Agent stdin. Shared with the stdout reader task, which writes replies to
    /// agent-initiated requests through it.
    stdin: Arc<Mutex<ChildStdin>>,
    /// The child process; `None` once `kill` has taken it.
    child: Arc<Mutex<Option<Child>>>,
    /// Requests sent by `send_request` that are still waiting for a response.
    pending: PendingMap,
    /// Next JSON-RPC request id (starts at 1).
    next_id: Arc<AtomicU64>,
    /// Cleared when the stdout reader task finishes or `kill` is called.
    alive: Arc<AtomicBool>,
    /// Where agent notifications and stderr output are published.
    notification_tx: NotificationSender,
    /// Human-readable name used in log lines and error messages.
    display_name: String,
    /// Command as passed to `spawn` (before PATH resolution); compared against
    /// `npx`/`uvx` to choose longer default timeouts.
    command: String,
    /// Handle of the stdout reader task; held but never awaited.
    _reader_handle: tokio::task::JoinHandle<()>,
}
50
51impl AcpProcess {
52 pub async fn spawn(
57 command: &str,
58 args: &[&str],
59 cwd: &str,
60 notification_tx: NotificationSender,
61 display_name: &str,
62 our_session_id: &str,
63 ) -> Result<Self, String> {
64 tracing::info!(
65 "[AcpProcess:{}] Spawning: {} {} (cwd: {})",
66 display_name,
67 command,
68 args.join(" "),
69 cwd,
70 );
71
72 let resolved_command =
75 crate::shell_env::which(command).unwrap_or_else(|| command.to_string());
76
77 let mut command_builder = tokio::process::Command::new(&resolved_command);
78 command_builder
79 .args(args)
80 .current_dir(cwd)
81 .env("PATH", crate::shell_env::full_path())
82 .env("NODE_NO_READLINE", "1")
83 .stdin(std::process::Stdio::piped())
84 .stdout(std::process::Stdio::piped())
85 .stderr(std::process::Stdio::piped());
86
87 #[cfg(windows)]
88 command_builder
89 .as_std_mut()
90 .creation_flags(CREATE_NO_WINDOW);
91
92 if resolved_command.ends_with("codex-acp") && std::env::var_os("RUST_LOG").is_none() {
96 command_builder.env(
97 "RUST_LOG",
98 "info,codex_acp::thread=info,codex_acp::codex_agent=info",
99 );
100 }
101
102 let mut child = command_builder.spawn().map_err(|e| {
103 format!(
104 "Failed to spawn '{}' (resolved: '{}'): {}. Is it installed and in PATH?",
105 command, resolved_command, e
106 )
107 })?;
108
109 let stdin = child
110 .stdin
111 .take()
112 .ok_or_else(|| "No stdin on child process".to_string())?;
113 let stdout = child
114 .stdout
115 .take()
116 .ok_or_else(|| "No stdout on child process".to_string())?;
117 let stderr = child.stderr.take();
118
119 let alive = Arc::new(AtomicBool::new(true));
120 let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
121 let stdin = Arc::new(Mutex::new(stdin));
122
123 let name = display_name.to_string();
124
125 if let Some(stderr) = stderr {
127 let name_clone = name.clone();
128 let ntx_stderr = notification_tx.clone();
129 let our_sid_stderr = our_session_id.to_string();
130 tokio::spawn(async move {
131 let reader = BufReader::new(stderr);
132 let mut lines = reader.lines();
133 while let Ok(Some(line)) = lines.next_line().await {
134 if !line.trim().is_empty() {
135 tracing::debug!("[AcpProcess:{} stderr] {}", name_clone, line);
136 let notification = serde_json::json!({
138 "jsonrpc": "2.0",
139 "method": "session/update",
140 "params": {
141 "sessionId": our_sid_stderr,
142 "update": {
143 "sessionUpdate": "process_output",
144 "source": "stderr",
145 "data": format!("{}\n", line),
146 "displayName": name_clone,
147 }
148 }
149 });
150 let _ = ntx_stderr.send(notification);
151 }
152 }
153 });
154 }
155
156 let alive_clone = alive.clone();
158 let pending_clone = pending.clone();
159 let ntx = notification_tx.clone();
160 let stdin_clone = stdin.clone();
161 let name_clone = name.clone();
162 let our_sid = our_session_id.to_string();
163 let cwd_clone = cwd.to_string();
164 let provider_clone = display_name.to_string();
165
166 let reader_handle = tokio::spawn(async move {
167 let reader = BufReader::new(stdout);
168 let mut lines = reader.lines();
169 let mut agent_msg_buffer = String::new();
171 let mut agent_thought_buffer = String::new();
173 let mut pending_tool_calls: std::collections::HashMap<String, (String, bool)> =
175 std::collections::HashMap::new();
176
177 while let Ok(Some(line)) = lines.next_line().await {
178 let line = line.trim().to_string();
179 if line.is_empty() {
180 continue;
181 }
182
183 let msg: serde_json::Value = match serde_json::from_str(&line) {
184 Ok(v) => v,
185 Err(_) => {
186 if let Some(v) = try_parse_embedded_json(&line) {
188 v
189 } else {
190 tracing::debug!(
191 "[AcpProcess:{}] Non-JSON stdout: {}",
192 name_clone,
193 &line[..line.len().min(200)]
194 );
195 continue;
196 }
197 }
198 };
199
200 let has_id = msg.get("id").is_some() && !msg.get("id").unwrap().is_null();
201 let has_result = msg.get("result").is_some();
202 let has_error = msg.get("error").is_some();
203 let has_method = msg.get("method").and_then(|m| m.as_str()).is_some();
204
205 if has_id && (has_result || has_error) {
206 let id = msg["id"].as_u64().unwrap_or(0);
208 let mut map = pending_clone.lock().await;
209 if let Some(tx) = map.remove(&id) {
210 if has_error {
211 let err_msg =
212 msg["error"]["message"].as_str().unwrap_or("unknown error");
213 let err_code = msg["error"]["code"].as_i64().unwrap_or(0);
214 let _ = tx.send(Err(format!("ACP Error [{}]: {}", err_code, err_msg)));
215 } else {
216 let _ = tx.send(Ok(msg["result"].clone()));
217 }
218 }
219 } else if has_id && has_method {
220 let method = msg["method"].as_str().unwrap_or("");
222 let id_val = msg["id"].clone();
223 tracing::info!(
224 "[AcpProcess:{}] Agent request: {} (id={})",
225 name_clone,
226 method,
227 id_val
228 );
229 let response =
230 handle_agent_request(method, &msg["params"], &our_sid, &ntx).await;
231 let reply = serde_json::json!({
232 "jsonrpc": "2.0",
233 "id": id_val,
234 "result": response,
235 });
236 let data = format!("{}\n", serde_json::to_string(&reply).unwrap());
237 let mut stdin = stdin_clone.lock().await;
238 let _ = stdin.write_all(data.as_bytes()).await;
239 let _ = stdin.flush().await;
240 } else if has_method {
241 let mut rewritten = msg.clone();
245 if let Some(params) = rewritten.get_mut("params") {
246 if params.get("sessionId").is_some() {
247 params["sessionId"] = serde_json::Value::String(our_sid.clone());
248 }
249 }
250
251 if let Some(params) = msg.get("params") {
253 if let Some(update) = params.get("update") {
254 let session_update = update
255 .get("sessionUpdate")
256 .and_then(|v| v.as_str())
257 .unwrap_or("");
258
259 match session_update {
260 "agent_thought_chunk" => {
261 let text = update
263 .get("content")
264 .and_then(|c| c.get("text"))
265 .and_then(|t| t.as_str())
266 .unwrap_or("");
267 agent_thought_buffer.push_str(text);
268 if agent_thought_buffer.len() >= 100 {
270 let record = TraceRecord::new(
271 &our_sid,
272 TraceEventType::AgentThought,
273 Contributor::new(&provider_clone, None),
274 )
275 .with_conversation(TraceConversation {
276 turn: None,
277 role: Some("assistant".to_string()),
278 content_preview: Some(
279 agent_thought_buffer
280 [..agent_thought_buffer.len().min(200)]
281 .to_string(),
282 ),
283 full_content: Some(agent_thought_buffer.clone()),
284 });
285 let writer = TraceWriter::new(&cwd_clone);
286 let _ = writer.append_safe(&record).await;
287 agent_thought_buffer.clear();
288 }
289 }
290 "agent_message_chunk" => {
291 let text = update
293 .get("content")
294 .and_then(|c| c.get("text"))
295 .and_then(|t| t.as_str())
296 .unwrap_or("");
297 agent_msg_buffer.push_str(text);
298 if agent_msg_buffer.len() >= 100 {
300 let record = TraceRecord::new(
301 &our_sid,
302 TraceEventType::AgentMessage,
303 Contributor::new(&provider_clone, None),
304 )
305 .with_conversation(TraceConversation {
306 turn: None,
307 role: Some("assistant".to_string()),
308 content_preview: Some(
309 agent_msg_buffer[..agent_msg_buffer.len().min(200)]
310 .to_string(),
311 ),
312 full_content: Some(agent_msg_buffer.clone()),
313 });
314 let writer = TraceWriter::new(&cwd_clone);
315 let _ = writer.append_safe(&record).await;
316 agent_msg_buffer.clear();
317 }
318 }
319 "agent_message" => {
320 let text = update
322 .get("content")
323 .and_then(|c| c.get("text"))
324 .and_then(|t| t.as_str())
325 .unwrap_or("");
326 let record = TraceRecord::new(
327 &our_sid,
328 TraceEventType::AgentMessage,
329 Contributor::new(&provider_clone, None),
330 )
331 .with_conversation(TraceConversation {
332 turn: None,
333 role: Some("assistant".to_string()),
334 content_preview: Some(
335 text[..text.len().min(200)].to_string(),
336 ),
337 full_content: Some(text.to_string()),
338 });
339 let writer = TraceWriter::new(&cwd_clone);
340 let _ = writer.append_safe(&record).await;
341 }
342 "tool_call" => {
343 let tool_call_id =
345 update.get("toolCallId").and_then(|v| v.as_str());
346 let kind = update
347 .get("kind")
348 .and_then(|v| v.as_str())
349 .or_else(|| update.get("title").and_then(|v| v.as_str()))
350 .unwrap_or("unknown");
351 let raw_input = update.get("rawInput").cloned();
352
353 let has_input = raw_input.as_ref().is_some_and(|v| {
355 if let Some(obj) = v.as_object() {
356 !obj.is_empty()
357 } else {
358 !v.is_null()
359 }
360 });
361
362 if has_input {
363 let record = TraceRecord::new(
365 &our_sid,
366 TraceEventType::ToolCall,
367 Contributor::new(&provider_clone, None),
368 )
369 .with_tool(TraceTool {
370 name: kind.to_string(),
371 tool_call_id: tool_call_id.map(|s| s.to_string()),
372 status: Some("running".to_string()),
373 input: raw_input,
374 output: None,
375 });
376 let writer = TraceWriter::new(&cwd_clone);
377 let _ = writer.append_safe(&record).await;
378 } else if let Some(id) = tool_call_id {
379 pending_tool_calls
381 .insert(id.to_string(), (kind.to_string(), false));
382 }
383 }
384 "tool_call_update" => {
385 let tool_call_id =
387 update.get("toolCallId").and_then(|v| v.as_str());
388 let kind = update
389 .get("kind")
390 .and_then(|v| v.as_str())
391 .or_else(|| update.get("title").and_then(|v| v.as_str()))
392 .unwrap_or("unknown");
393 let raw_input = update.get("rawInput").cloned();
394 let raw_output = update
395 .get("rawOutput")
396 .and_then(|v| v.as_str())
397 .map(|s| serde_json::Value::String(s.to_string()))
398 .or_else(|| update.get("rawOutput").cloned());
399 let status = update
400 .get("status")
401 .and_then(|v| v.as_str())
402 .unwrap_or("completed");
403
404 let has_input = raw_input.as_ref().is_some_and(|v| {
406 if let Some(obj) = v.as_object() {
407 !obj.is_empty()
408 } else {
409 !v.is_null()
410 }
411 });
412
413 if let Some(id) = tool_call_id {
414 if let Some((stored_kind, traced)) =
415 pending_tool_calls.get_mut(id)
416 {
417 if has_input && !*traced {
418 let record = TraceRecord::new(
420 &our_sid,
421 TraceEventType::ToolCall,
422 Contributor::new(&provider_clone, None),
423 )
424 .with_tool(TraceTool {
425 name: stored_kind.clone(),
426 tool_call_id: Some(id.to_string()),
427 status: Some("running".to_string()),
428 input: raw_input.clone(),
429 output: None,
430 });
431 let writer = TraceWriter::new(&cwd_clone);
432 let _ = writer.append_safe(&record).await;
433 *traced = true;
434 }
435 }
436 }
437
438 let is_complete = status == "completed"
440 || status == "failed"
441 || raw_output.is_some();
442 if is_complete {
443 let record = TraceRecord::new(
444 &our_sid,
445 TraceEventType::ToolResult,
446 Contributor::new(&provider_clone, None),
447 )
448 .with_tool(TraceTool {
449 name: kind.to_string(),
450 tool_call_id: tool_call_id.map(|s| s.to_string()),
451 status: Some(status.to_string()),
452 input: None,
453 output: raw_output,
454 });
455 let writer = TraceWriter::new(&cwd_clone);
456 let _ = writer.append_safe(&record).await;
457
458 if let Some(id) = tool_call_id {
460 pending_tool_calls.remove(id);
461 }
462 }
463 }
464 _ => {}
465 }
466 }
467 }
468
469 let _ = ntx.send(rewritten);
470 } else {
471 tracing::debug!(
472 "[AcpProcess:{}] Unhandled message: {}",
473 name_clone,
474 &line[..line.len().min(200)]
475 );
476 }
477 }
478
479 if !agent_msg_buffer.is_empty() {
481 let record = TraceRecord::new(
482 &our_sid,
483 TraceEventType::AgentMessage,
484 Contributor::new(&provider_clone, None),
485 )
486 .with_conversation(TraceConversation {
487 turn: None,
488 role: Some("assistant".to_string()),
489 content_preview: Some(
490 agent_msg_buffer[..agent_msg_buffer.len().min(200)].to_string(),
491 ),
492 full_content: Some(agent_msg_buffer.clone()),
493 });
494 let writer = TraceWriter::new(&cwd_clone);
495 let _ = writer.append_safe(&record).await;
496 }
497
498 if !agent_thought_buffer.is_empty() {
500 let record = TraceRecord::new(
501 &our_sid,
502 TraceEventType::AgentThought,
503 Contributor::new(&provider_clone, None),
504 )
505 .with_conversation(TraceConversation {
506 turn: None,
507 role: Some("assistant".to_string()),
508 content_preview: Some(
509 agent_thought_buffer[..agent_thought_buffer.len().min(200)].to_string(),
510 ),
511 full_content: Some(agent_thought_buffer.clone()),
512 });
513 let writer = TraceWriter::new(&cwd_clone);
514 let _ = writer.append_safe(&record).await;
515 }
516
517 alive_clone.store(false, Ordering::SeqCst);
518 tracing::info!("[AcpProcess:{}] stdout reader finished", name_clone);
519 });
520
521 tokio::time::sleep(Duration::from_millis(300)).await;
523
524 if !alive.load(Ordering::SeqCst) {
525 return Err(format!("{} process died during startup", display_name));
526 }
527
528 tracing::info!("[AcpProcess:{}] Process started", display_name);
529
530 Ok(Self {
531 stdin,
532 child: Arc::new(Mutex::new(Some(child))),
533 pending,
534 next_id: Arc::new(AtomicU64::new(1)),
535 alive,
536 notification_tx,
537 display_name: display_name.to_string(),
538 command: command.to_string(),
539 _reader_handle: reader_handle,
540 })
541 }
542
543 pub fn is_alive(&self) -> bool {
545 self.alive.load(Ordering::SeqCst)
546 }
547
548 pub async fn send_request(
550 &self,
551 method: &str,
552 params: serde_json::Value,
553 timeout_ms: Option<u64>,
554 ) -> Result<serde_json::Value, String> {
555 if !self.is_alive() {
556 return Err(format!("{} process is not alive", self.display_name));
557 }
558
559 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
560 let (tx, rx) = oneshot::channel();
561
562 self.pending.lock().await.insert(id, tx);
563
564 let msg = serde_json::json!({
565 "jsonrpc": "2.0",
566 "id": id,
567 "method": method,
568 "params": params,
569 });
570 let data = format!("{}\n", serde_json::to_string(&msg).unwrap());
571
572 {
573 let mut stdin = self.stdin.lock().await;
574 stdin
575 .write_all(data.as_bytes())
576 .await
577 .map_err(|e| format!("Write {}: {}", method, e))?;
578 stdin
579 .flush()
580 .await
581 .map_err(|e| format!("Flush {}: {}", method, e))?;
582 }
583
584 let is_npx_or_uvx = self.command == "npx" || self.command == "uvx";
587 let default_timeout = match method {
588 "initialize" | "session/new" => {
589 if is_npx_or_uvx {
590 120_000 } else {
592 15_000 }
594 }
595 "session/prompt" => 300_000, _ => 30_000,
597 };
598 let timeout_dur = Duration::from_millis(timeout_ms.unwrap_or(default_timeout));
599
600 match tokio::time::timeout(timeout_dur, rx).await {
601 Ok(Ok(result)) => result,
602 Ok(Err(_)) => Err(format!("Channel closed for {} (id={})", method, id)),
603 Err(_) => {
604 self.pending.lock().await.remove(&id);
605 Err(format!(
606 "Timeout waiting for {} (id={}, {}ms)",
607 method,
608 id,
609 timeout_dur.as_millis()
610 ))
611 }
612 }
613 }
614
615 pub async fn initialize(&self) -> Result<serde_json::Value, String> {
617 self.initialize_with_timeout(None).await
618 }
619
620 pub async fn initialize_with_timeout(
622 &self,
623 timeout_ms: Option<u64>,
624 ) -> Result<serde_json::Value, String> {
625 let result = self
626 .send_request(
627 "initialize",
628 serde_json::json!({
629 "protocolVersion": 1,
630 "clientInfo": {
631 "name": "routa-desktop",
632 "version": "0.1.0"
633 }
634 }),
635 timeout_ms,
636 )
637 .await?;
638 tracing::info!(
639 "[AcpProcess:{}] Initialized: {}",
640 self.display_name,
641 serde_json::to_string(&result).unwrap_or_default()
642 );
643 Ok(result)
644 }
645
646 pub async fn new_session(&self, cwd: &str) -> Result<String, String> {
648 let result = self
649 .send_request(
650 "session/new",
651 serde_json::json!({
652 "cwd": cwd,
653 "mcpServers": []
654 }),
655 None,
656 )
657 .await?;
658
659 let session_id = result["sessionId"]
660 .as_str()
661 .ok_or_else(|| "No sessionId in session/new response".to_string())?
662 .to_string();
663
664 tracing::info!(
665 "[AcpProcess:{}] Session created: {}",
666 self.display_name,
667 session_id
668 );
669 Ok(session_id)
670 }
671
672 pub async fn prompt(&self, session_id: &str, text: &str) -> Result<serde_json::Value, String> {
674 self.send_request(
675 "session/prompt",
676 serde_json::json!({
677 "sessionId": session_id,
678 "prompt": [{ "type": "text", "text": text }]
679 }),
680 Some(300_000),
681 )
682 .await
683 }
684
685 pub async fn cancel(&self, session_id: &str) {
687 let msg = serde_json::json!({
688 "jsonrpc": "2.0",
689 "method": "session/cancel",
690 "params": { "sessionId": session_id }
691 });
692 let data = format!("{}\n", serde_json::to_string(&msg).unwrap());
693 let mut stdin = self.stdin.lock().await;
694 let _ = stdin.write_all(data.as_bytes()).await;
695 let _ = stdin.flush().await;
696 }
697
698 pub fn notification_sender(&self) -> &NotificationSender {
700 &self.notification_tx
701 }
702
703 pub async fn kill(&self) {
705 self.alive.store(false, Ordering::SeqCst);
706 if let Some(mut child) = self.child.lock().await.take() {
707 tracing::info!("[AcpProcess:{}] Killing process", self.display_name);
708 let _ = child.kill().await;
709 }
710 let mut map = self.pending.lock().await;
712 for (_, tx) in map.drain() {
713 let _ = tx.send(Err("Process killed".to_string()));
714 }
715 }
716}
717
718async fn handle_agent_request(
720 method: &str,
721 params: &serde_json::Value,
722 session_id: &str,
723 notification_tx: &NotificationSender,
724) -> serde_json::Value {
725 match method {
726 "session/request_permission" => {
727 serde_json::json!({
729 "outcome": { "outcome": "approved" }
730 })
731 }
732 "fs/read_text_file" => {
733 let path = params["path"].as_str().unwrap_or("");
734 match tokio::fs::read_to_string(path).await {
735 Ok(content) => serde_json::json!({ "content": content }),
736 Err(e) => serde_json::json!({
737 "error": format!("Failed to read file: {}", e)
738 }),
739 }
740 }
741 "fs/write_text_file" => {
742 let path = params["path"].as_str().unwrap_or("");
743 let content = params["content"].as_str().unwrap_or("");
744 if let Some(parent) = std::path::Path::new(path).parent() {
745 let _ = tokio::fs::create_dir_all(parent).await;
746 }
747 match tokio::fs::write(path, content).await {
748 Ok(_) => serde_json::json!({}),
749 Err(e) => serde_json::json!({
750 "error": format!("Failed to write file: {}", e)
751 }),
752 }
753 }
754 "terminal/create" => {
755 match TerminalManager::global()
756 .create(params, session_id, notification_tx)
757 .await
758 {
759 Ok(result) => result,
760 Err(error) => serde_json::json!({ "error": error }),
761 }
762 }
763 "terminal/output" => {
764 let terminal_id = params["terminalId"].as_str().unwrap_or("");
765 match TerminalManager::global().get_output(terminal_id).await {
766 Ok(result) => result,
767 Err(error) => serde_json::json!({ "error": error }),
768 }
769 }
770 "terminal/wait_for_exit" => {
771 let terminal_id = params["terminalId"].as_str().unwrap_or("");
772 match TerminalManager::global().wait_for_exit(terminal_id).await {
773 Ok(result) => result,
774 Err(error) => serde_json::json!({ "error": error }),
775 }
776 }
777 "terminal/kill" => {
778 let terminal_id = params["terminalId"].as_str().unwrap_or("");
779 match TerminalManager::global().kill(terminal_id).await {
780 Ok(_) => serde_json::json!({}),
781 Err(error) => serde_json::json!({ "error": error }),
782 }
783 }
784 "terminal/release" => {
785 let terminal_id = params["terminalId"].as_str().unwrap_or("");
786 TerminalManager::global().release(terminal_id).await;
787 serde_json::json!({})
788 }
789 _ => {
790 tracing::warn!("[AcpProcess] Unknown agent request: {}", method);
791 serde_json::json!({})
792 }
793 }
794}
795
796fn try_parse_embedded_json(line: &str) -> Option<serde_json::Value> {
798 let mut depth = 0i32;
799 let mut start = None;
800
801 for (i, ch) in line.char_indices() {
802 match ch {
803 '{' => {
804 if depth == 0 {
805 start = Some(i);
806 }
807 depth += 1;
808 }
809 '}' => {
810 depth -= 1;
811 if depth == 0 {
812 if let Some(s) = start {
813 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&line[s..=i]) {
814 return Some(v);
815 }
816 }
817 start = None;
818 }
819 }
820 _ => {}
821 }
822 }
823 None
824}