// routa_core/acp/terminal_manager.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::process::{Child, ChildStdin, Command};
6use tokio::sync::Mutex;
7
8use super::process::NotificationSender;
9
10#[derive(Clone)]
11struct ManagedTerminal {
12    terminal_id: String,
13    session_id: String,
14    child: Arc<Mutex<Child>>,
15    stdin: Arc<Mutex<Option<ChildStdin>>>,
16    output: Arc<Mutex<String>>,
17    exit_code: Arc<Mutex<Option<i32>>>,
18    cols: Arc<Mutex<Option<u16>>>,
19    rows: Arc<Mutex<Option<u16>>>,
20}
21
22#[derive(Clone, Default)]
23pub struct TerminalManager {
24    terminals: Arc<Mutex<HashMap<String, ManagedTerminal>>>,
25    counter: Arc<Mutex<u64>>,
26}
27
28impl TerminalManager {
29    pub fn global() -> &'static Self {
30        static INSTANCE: OnceLock<TerminalManager> = OnceLock::new();
31        INSTANCE.get_or_init(TerminalManager::default)
32    }
33
34    pub async fn create(
35        &self,
36        params: &serde_json::Value,
37        session_id: &str,
38        notification_tx: &NotificationSender,
39    ) -> Result<serde_json::Value, String> {
40        let terminal_id = {
41            let mut counter = self.counter.lock().await;
42            *counter += 1;
43            format!(
44                "term-{}-{}",
45                *counter,
46                chrono::Utc::now().timestamp_millis()
47            )
48        };
49
50        let command = params
51            .get("command")
52            .and_then(|value| value.as_str())
53            .unwrap_or("/bin/bash");
54        let args = params
55            .get("args")
56            .and_then(|value| value.as_array())
57            .map(|items| {
58                items
59                    .iter()
60                    .filter_map(|item| item.as_str().map(str::to_string))
61                    .collect::<Vec<_>>()
62            })
63            .unwrap_or_default();
64        let cwd = params
65            .get("cwd")
66            .and_then(|value| value.as_str())
67            .unwrap_or(".");
68        let cols = params
69            .get("cols")
70            .and_then(|value| value.as_u64())
71            .map(|value| value as u16);
72        let rows = params
73            .get("rows")
74            .and_then(|value| value.as_u64())
75            .map(|value| value as u16);
76
77        let mut command_builder = Command::new(command);
78        command_builder
79            .args(&args)
80            .current_dir(cwd)
81            .env("PATH", crate::shell_env::full_path())
82            .env("TERM", "xterm-256color")
83            .env("FORCE_COLOR", "1")
84            .stdin(std::process::Stdio::piped())
85            .stdout(std::process::Stdio::piped())
86            .stderr(std::process::Stdio::piped());
87
88        let mut child = command_builder
89            .spawn()
90            .map_err(|error| format!("Failed to spawn terminal process: {error}"))?;
91
92        let stdin = child.stdin.take();
93        let stdout = child.stdout.take();
94        let stderr = child.stderr.take();
95
96        let managed = ManagedTerminal {
97            terminal_id: terminal_id.clone(),
98            session_id: session_id.to_string(),
99            child: Arc::new(Mutex::new(child)),
100            stdin: Arc::new(Mutex::new(stdin)),
101            output: Arc::new(Mutex::new(String::new())),
102            exit_code: Arc::new(Mutex::new(None)),
103            cols: Arc::new(Mutex::new(cols)),
104            rows: Arc::new(Mutex::new(rows)),
105        };
106
107        self.terminals
108            .lock()
109            .await
110            .insert(terminal_id.clone(), managed.clone());
111
112        emit_terminal_update(
113            notification_tx,
114            session_id,
115            serde_json::json!({
116                "sessionUpdate": "terminal_created",
117                "terminalId": terminal_id,
118                "command": command,
119                "args": args,
120            }),
121        );
122
123        if let Some(stdout) = stdout {
124            spawn_output_forwarder(managed.clone(), stdout, notification_tx.clone());
125        }
126        if let Some(stderr) = stderr {
127            spawn_output_forwarder(managed.clone(), stderr, notification_tx.clone());
128        }
129        spawn_exit_watcher(managed, notification_tx.clone());
130
131        Ok(serde_json::json!({ "terminalId": terminal_id }))
132    }
133
134    pub async fn has_terminal(&self, session_id: &str, terminal_id: &str) -> bool {
135        self.terminals
136            .lock()
137            .await
138            .get(terminal_id)
139            .map(|terminal| terminal.session_id == session_id)
140            .unwrap_or(false)
141    }
142
143    pub async fn write(&self, terminal_id: &str, data: &str) -> Result<(), String> {
144        let terminal = self
145            .terminals
146            .lock()
147            .await
148            .get(terminal_id)
149            .cloned()
150            .ok_or_else(|| "Terminal not found".to_string())?;
151        let mut stdin_guard = terminal.stdin.lock().await;
152        let stdin = stdin_guard
153            .as_mut()
154            .ok_or_else(|| "Terminal is not writable".to_string())?;
155        stdin
156            .write_all(data.as_bytes())
157            .await
158            .map_err(|error| format!("Failed to write terminal input: {error}"))?;
159        stdin
160            .flush()
161            .await
162            .map_err(|error| format!("Failed to flush terminal input: {error}"))?;
163        Ok(())
164    }
165
166    pub async fn resize(
167        &self,
168        terminal_id: &str,
169        cols: Option<u16>,
170        rows: Option<u16>,
171    ) -> Result<(), String> {
172        let terminal = self
173            .terminals
174            .lock()
175            .await
176            .get(terminal_id)
177            .cloned()
178            .ok_or_else(|| "Terminal not found".to_string())?;
179        if let Some(cols) = cols {
180            *terminal.cols.lock().await = Some(cols);
181        }
182        if let Some(rows) = rows {
183            *terminal.rows.lock().await = Some(rows);
184        }
185        Ok(())
186    }
187
188    pub async fn get_output(&self, terminal_id: &str) -> Result<serde_json::Value, String> {
189        let terminal = self
190            .terminals
191            .lock()
192            .await
193            .get(terminal_id)
194            .cloned()
195            .ok_or_else(|| "Terminal not found".to_string())?;
196        let output = terminal.output.lock().await.clone();
197        Ok(serde_json::json!({ "output": output }))
198    }
199
200    pub async fn wait_for_exit(&self, terminal_id: &str) -> Result<serde_json::Value, String> {
201        let terminal = self
202            .terminals
203            .lock()
204            .await
205            .get(terminal_id)
206            .cloned()
207            .ok_or_else(|| "Terminal not found".to_string())?;
208        loop {
209            if let Some(code) = *terminal.exit_code.lock().await {
210                return Ok(serde_json::json!({ "exitCode": code }));
211            }
212            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
213        }
214    }
215
216    pub async fn kill(&self, terminal_id: &str) -> Result<(), String> {
217        let terminal = self
218            .terminals
219            .lock()
220            .await
221            .get(terminal_id)
222            .cloned()
223            .ok_or_else(|| "Terminal not found".to_string())?;
224        let mut child = terminal.child.lock().await;
225        child
226            .kill()
227            .await
228            .map_err(|error| format!("Failed to kill terminal: {error}"))
229    }
230
231    pub async fn release(&self, terminal_id: &str) {
232        self.terminals.lock().await.remove(terminal_id);
233    }
234}
235
236fn emit_terminal_update(
237    notification_tx: &NotificationSender,
238    session_id: &str,
239    update: serde_json::Value,
240) {
241    let _ = notification_tx.send(serde_json::json!({
242        "jsonrpc": "2.0",
243        "method": "session/update",
244        "params": {
245            "sessionId": session_id,
246            "update": update,
247        }
248    }));
249}
250
251fn spawn_output_forwarder<R>(
252    terminal: ManagedTerminal,
253    mut reader: R,
254    notification_tx: NotificationSender,
255) where
256    R: tokio::io::AsyncRead + Unpin + Send + 'static,
257{
258    tokio::spawn(async move {
259        let mut buffer = [0u8; 4096];
260        loop {
261            match reader.read(&mut buffer).await {
262                Ok(0) => break,
263                Ok(size) => {
264                    let data = String::from_utf8_lossy(&buffer[..size]).to_string();
265                    terminal.output.lock().await.push_str(&data);
266                    emit_terminal_update(
267                        &notification_tx,
268                        &terminal.session_id,
269                        serde_json::json!({
270                            "sessionUpdate": "terminal_output",
271                            "terminalId": terminal.terminal_id,
272                            "data": data,
273                        }),
274                    );
275                }
276                Err(_) => break,
277            }
278        }
279    });
280}
281
282fn spawn_exit_watcher(terminal: ManagedTerminal, notification_tx: NotificationSender) {
283    tokio::spawn(async move {
284        let code = loop {
285            let maybe_status = {
286                let mut child = terminal.child.lock().await;
287                child.try_wait().ok().flatten()
288            };
289            if let Some(status) = maybe_status {
290                break status.code().unwrap_or(0);
291            }
292            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
293        };
294        *terminal.exit_code.lock().await = Some(code);
295        emit_terminal_update(
296            &notification_tx,
297            &terminal.session_id,
298            serde_json::json!({
299                "sessionUpdate": "terminal_exited",
300                "terminalId": terminal.terminal_id,
301                "exitCode": code,
302            }),
303        );
304    });
305}
306
#[cfg(test)]
mod tests {
    use super::TerminalManager;
    use tokio::sync::broadcast;

    /// Pulls the `terminalId` string out of a `create` response.
    #[cfg(not(windows))]
    fn terminal_id_of(created: &serde_json::Value) -> String {
        created["terminalId"]
            .as_str()
            .expect("terminal id")
            .to_owned()
    }

    /// Tears a terminal down the way a client would: kill it, then release it.
    #[cfg(not(windows))]
    async fn kill_and_release(manager: &TerminalManager, terminal_id: &str) {
        manager.kill(terminal_id).await.expect("kill terminal");
        manager.release(terminal_id).await;
    }

    #[cfg(not(windows))]
    #[tokio::test]
    async fn create_write_and_read_terminal_output() {
        let manager = TerminalManager::default();
        let (tx, _rx) = broadcast::channel(32);
        let params = serde_json::json!({
            "command": "/bin/cat",
            "args": [],
            "cwd": "/tmp"
        });

        let created = manager
            .create(&params, "session-1", &tx)
            .await
            .expect("create terminal");
        let terminal_id = terminal_id_of(&created);
        assert!(manager.has_terminal("session-1", &terminal_id).await);

        manager
            .write(&terminal_id, "hello from terminal\n")
            .await
            .expect("write terminal");

        // `cat` echoes stdin back on stdout; poll up to 20 times, 100 ms apart.
        let mut attempts_left = 20;
        let saw_output = loop {
            if attempts_left == 0 {
                break false;
            }
            attempts_left -= 1;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            let output = manager.get_output(&terminal_id).await.expect("get output");
            let text = output["output"].as_str().expect("output string");
            if text.contains("hello from terminal") {
                break true;
            }
        };
        assert!(saw_output, "terminal output should contain echoed input");

        kill_and_release(&manager, &terminal_id).await;
    }

    #[cfg(not(windows))]
    #[tokio::test]
    async fn resize_tracks_terminal_without_failing() {
        let manager = TerminalManager::default();
        let (tx, _rx) = broadcast::channel(32);
        let params = serde_json::json!({
            "command": "/bin/cat",
            "args": [],
            "cwd": "/tmp",
            "cols": 80,
            "rows": 24
        });

        let created = manager
            .create(&params, "session-2", &tx)
            .await
            .expect("create terminal");
        let terminal_id = terminal_id_of(&created);

        manager
            .resize(&terminal_id, Some(120), Some(40))
            .await
            .expect("resize terminal");

        kill_and_release(&manager, &terminal_id).await;
    }
}