agent_tui/terminal/
pty.rs

1use std::collections::HashMap;
2use std::io::Read;
3use std::io::Write;
4use std::os::fd::RawFd;
5use std::sync::Arc;
6use std::sync::Mutex;
7
8use portable_pty::Child;
9use portable_pty::CommandBuilder;
10use portable_pty::MasterPty;
11use portable_pty::PtySize;
12use portable_pty::native_pty_system;
13
14use crate::common::mutex_lock_or_recover;
15
16pub use crate::terminal::error::PtyError;
17
/// A child process attached to a pseudo-terminal, together with the handles
/// used to read from it, write to it, resize it and terminate it.
///
/// Dropping a `PtyHandle` kills the child if it is still running.
pub struct PtyHandle {
    /// Master side of the PTY pair; `resize` is applied through it.
    master: Box<dyn MasterPty + Send>,
    /// The process spawned on the slave side of the PTY.
    child: Box<dyn Child + Send + Sync>,
    /// Reader cloned from the master in `spawn`, behind a mutex so it can be
    /// used through `&self`.
    reader: Arc<Mutex<Box<dyn Read + Send>>>,
    /// Writer taken from the master in `spawn`, behind a mutex so it can be
    /// used through `&self`.
    writer: Arc<Mutex<Box<dyn Write + Send>>>,
    /// Size recorded for the PTY; written by `spawn` and `resize`.
    size: PtySize,
    /// Raw file descriptor of the master, polled by `try_read` before reading.
    reader_fd: RawFd,
}
26
27impl Drop for PtyHandle {
28    fn drop(&mut self) {
29        if self.is_running() {
30            let _ = self.kill();
31        }
32    }
33}
34
35impl PtyHandle {
36    pub fn spawn(
37        command: &str,
38        args: &[String],
39        cwd: Option<&str>,
40        env: Option<&HashMap<String, String>>,
41        cols: u16,
42        rows: u16,
43    ) -> Result<Self, PtyError> {
44        let pty_system = native_pty_system();
45
46        let size = PtySize {
47            rows,
48            cols,
49            pixel_width: 0,
50            pixel_height: 0,
51        };
52
53        let pair = pty_system
54            .openpty(size)
55            .map_err(|e| PtyError::Open(e.to_string()))?;
56
57        let mut cmd = CommandBuilder::new(command);
58        cmd.args(args);
59
60        if let Some(dir) = cwd {
61            cmd.cwd(dir);
62        }
63
64        if let Some(env_vars) = env {
65            for (key, value) in env_vars {
66                cmd.env(key, value);
67            }
68        }
69
70        cmd.env("TERM", "xterm-256color");
71
72        let child = pair
73            .slave
74            .spawn_command(cmd)
75            .map_err(|e| PtyError::Spawn(e.to_string()))?;
76
77        let reader = pair
78            .master
79            .try_clone_reader()
80            .map_err(|e| PtyError::Open(e.to_string()))?;
81
82        let reader_fd = pair
83            .master
84            .as_raw_fd()
85            .ok_or_else(|| PtyError::Open("Failed to get master fd".to_string()))?;
86
87        let writer = pair
88            .master
89            .take_writer()
90            .map_err(|e| PtyError::Open(e.to_string()))?;
91
92        Ok(Self {
93            master: pair.master,
94            child,
95            reader: Arc::new(Mutex::new(reader)),
96            writer: Arc::new(Mutex::new(writer)),
97            size,
98            reader_fd,
99        })
100    }
101
102    pub fn pid(&self) -> Option<u32> {
103        self.child.process_id()
104    }
105
106    pub fn is_running(&mut self) -> bool {
107        self.child
108            .try_wait()
109            .map(|status| status.is_none())
110            .unwrap_or(false)
111    }
112
113    pub fn write(&self, data: &[u8]) -> Result<(), PtyError> {
114        let mut writer = mutex_lock_or_recover(&self.writer);
115        writer
116            .write_all(data)
117            .map_err(|e| PtyError::Write(e.to_string()))?;
118        writer.flush().map_err(|e| PtyError::Write(e.to_string()))?;
119        Ok(())
120    }
121
122    pub fn write_str(&self, s: &str) -> Result<(), PtyError> {
123        self.write(s.as_bytes())
124    }
125
126    pub fn try_read(&self, buf: &mut [u8], timeout_ms: i32) -> Result<usize, PtyError> {
127        let mut pollfd = libc::pollfd {
128            fd: self.reader_fd,
129            events: libc::POLLIN,
130            revents: 0,
131        };
132
133        // SAFETY: poll() is safe with valid pollfd. fd is from PTY master,
134        // timeout is passed directly, and return value is checked for errors.
135        let result = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
136
137        if result < 0 {
138            return Err(PtyError::Read("poll failed".to_string()));
139        }
140
141        if result == 0 {
142            return Ok(0);
143        }
144
145        let mut reader = mutex_lock_or_recover(&self.reader);
146        reader.read(buf).map_err(|e| PtyError::Read(e.to_string()))
147    }
148
149    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<(), PtyError> {
150        self.size = PtySize {
151            rows,
152            cols,
153            pixel_width: 0,
154            pixel_height: 0,
155        };
156        self.master
157            .resize(self.size)
158            .map_err(|e| PtyError::Resize(e.to_string()))
159    }
160
161    pub fn kill(&mut self) -> Result<(), PtyError> {
162        if !self.is_running() {
163            return Ok(());
164        }
165
166        self.child
167            .kill()
168            .map_err(|e| PtyError::Spawn(e.to_string()))
169    }
170}
171
/// Translates a key description such as `"Enter"`, `"Ctrl+C"` or `"a"` into
/// the bytes to send to a terminal for that key press.
///
/// Accepted forms:
///
/// * `Modifier+Key`, where the modifier (case-insensitive) is `Ctrl`/`Control`,
///   `Alt`/`Meta` or `Shift`. Only the first `+` separates the modifier, so
///   `Alt++` and nested forms such as `Alt+Ctrl+C` are understood.
/// * A named key (case-sensitive): `Enter`, `Tab`, `Escape`, `Backspace`,
///   `Delete`, `Space`, the arrow keys, `Home`, `End`, `PageUp`, `PageDown`,
///   `Insert` and `F1` to `F12`.
/// * Any single character, which is sent as its UTF-8 encoding.
///
/// Returns `None` when the key is not recognised.
pub fn key_to_escape_sequence(key: &str) -> Option<Vec<u8>> {
    if let Some((modifier, base_key)) = key.split_once('+') {
        // An empty modifier means the `+` is the key itself (a literal "+"),
        // so let it fall through to the plain-key handling below.
        if !modifier.is_empty() {
            return match modifier.to_lowercase().as_str() {
                "ctrl" | "control" => ctrl_sequence(base_key),
                "alt" | "meta" => {
                    // Alt is encoded as ESC followed by the unmodified key.
                    let mut sequence = vec![0x1b];
                    sequence.extend(key_to_escape_sequence(base_key)?);
                    Some(sequence)
                }
                "shift" => shift_sequence(base_key),
                _ => None,
            };
        }
    }

    let sequence: &[u8] = match key {
        "Enter" | "Return" => b"\r",
        "Tab" => b"\t",
        "Escape" | "Esc" => b"\x1b",
        "Backspace" => b"\x7f",
        "Delete" => b"\x1b[3~",
        "Space" => b" ",

        "ArrowUp" | "Up" => b"\x1b[A",
        "ArrowDown" | "Down" => b"\x1b[B",
        "ArrowRight" | "Right" => b"\x1b[C",
        "ArrowLeft" | "Left" => b"\x1b[D",

        "Home" => b"\x1b[H",
        "End" => b"\x1b[F",
        "PageUp" => b"\x1b[5~",
        "PageDown" => b"\x1b[6~",
        "Insert" => b"\x1b[2~",

        "F1" => b"\x1bOP",
        "F2" => b"\x1bOQ",
        "F3" => b"\x1bOR",
        "F4" => b"\x1bOS",
        "F5" => b"\x1b[15~",
        "F6" => b"\x1b[17~",
        "F7" => b"\x1b[18~",
        "F8" => b"\x1b[19~",
        "F9" => b"\x1b[20~",
        "F10" => b"\x1b[21~",
        "F11" => b"\x1b[23~",
        "F12" => b"\x1b[24~",

        // Anything else is only valid as a single literal character.
        _ => return is_single_char(key).then(|| key.as_bytes().to_vec()),
    };

    Some(sequence.to_vec())
}

/// Returns the control byte for `Ctrl+<base_key>`, or `None` if there is none.
fn ctrl_sequence(base_key: &str) -> Option<Vec<u8>> {
    if base_key.to_lowercase() == "space" {
        return Some(vec![0x00]);
    }

    // Every other supported base key is exactly one character.
    let mut chars = base_key.chars();
    let (Some(c), None) = (chars.next(), chars.next()) else {
        return None;
    };

    match c.to_ascii_uppercase() {
        // `@`, `A`-`Z`, `[`, `\`, `]`, `^` and `_` map onto the control codes
        // 0x00-0x1f by clearing the upper bits. The range is pure ASCII, so
        // the cast to `u8` cannot truncate.
        upper @ '@'..='_' => Some(vec![(upper as u8) & 0x1f]),
        ' ' => Some(vec![0x00]),
        '?' => Some(vec![0x7f]),
        _ => None,
    }
}

/// Returns the bytes for `Shift+<base_key>`, or `None` if unsupported.
fn shift_sequence(base_key: &str) -> Option<Vec<u8>> {
    if base_key.to_lowercase() == "tab" {
        // Back-tab.
        return Some(b"\x1b[Z".to_vec());
    }

    is_single_char(base_key).then(|| base_key.to_uppercase().into_bytes())
}

/// Returns `true` when `s` consists of exactly one character (not one byte).
fn is_single_char(s: &str) -> bool {
    let mut chars = s.chars();
    chars.next().is_some() && chars.next().is_none()
}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks one representative key from each category: named keys, an
    /// arrow key, a Ctrl combination and a plain character.
    #[test]
    fn test_key_to_escape_sequence() {
        let cases: [(&str, &[u8]); 6] = [
            ("Enter", b"\r"),
            ("Tab", b"\t"),
            ("Escape", b"\x1b"),
            ("ArrowUp", b"\x1b[A"),
            ("Ctrl+C", &[3]),
            ("a", b"a"),
        ];

        for (key, expected) in cases {
            assert_eq!(
                key_to_escape_sequence(key).as_deref(),
                Some(expected),
                "unexpected sequence for key {key:?}"
            );
        }
    }
}