agent_tui/
pty.rs

1use crate::sync_utils::mutex_lock_or_recover;
2use portable_pty::{native_pty_system, Child, CommandBuilder, MasterPty, PtySize};
3use std::io::{Read, Write};
4use std::os::fd::RawFd;
5use std::sync::{Arc, Mutex};
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum PtyError {
10    #[error("Failed to open PTY: {0}")]
11    Open(String),
12    #[error("Failed to spawn process: {0}")]
13    Spawn(String),
14    #[error("Failed to write to PTY: {0}")]
15    Write(String),
16    #[error("Failed to read from PTY: {0}")]
17    Read(String),
18    #[error("Failed to resize PTY: {0}")]
19    Resize(String),
20}
21
22pub struct PtyHandle {
23    master: Box<dyn MasterPty + Send>,
24    child: Box<dyn Child + Send + Sync>,
25    reader: Arc<Mutex<Box<dyn Read + Send>>>,
26    writer: Arc<Mutex<Box<dyn Write + Send>>>,
27    size: PtySize,
28    reader_fd: RawFd,
29}
30
31impl PtyHandle {
32    pub fn spawn(
33        command: &str,
34        args: &[String],
35        cwd: Option<&str>,
36        env: Option<&std::collections::HashMap<String, String>>,
37        cols: u16,
38        rows: u16,
39    ) -> Result<Self, PtyError> {
40        let pty_system = native_pty_system();
41
42        let size = PtySize {
43            rows,
44            cols,
45            pixel_width: 0,
46            pixel_height: 0,
47        };
48
49        let pair = pty_system
50            .openpty(size)
51            .map_err(|e| PtyError::Open(e.to_string()))?;
52
53        let mut cmd = CommandBuilder::new(command);
54        cmd.args(args);
55
56        if let Some(dir) = cwd {
57            cmd.cwd(dir);
58        }
59
60        if let Some(env_vars) = env {
61            for (key, value) in env_vars {
62                cmd.env(key, value);
63            }
64        }
65
66        cmd.env("TERM", "xterm-256color");
67
68        let child = pair
69            .slave
70            .spawn_command(cmd)
71            .map_err(|e| PtyError::Spawn(e.to_string()))?;
72
73        let reader = pair
74            .master
75            .try_clone_reader()
76            .map_err(|e| PtyError::Open(e.to_string()))?;
77
78        let reader_fd = pair
79            .master
80            .as_raw_fd()
81            .ok_or_else(|| PtyError::Open("Failed to get master fd".to_string()))?;
82
83        let writer = pair
84            .master
85            .take_writer()
86            .map_err(|e| PtyError::Open(e.to_string()))?;
87
88        Ok(Self {
89            master: pair.master,
90            child,
91            reader: Arc::new(Mutex::new(reader)),
92            writer: Arc::new(Mutex::new(writer)),
93            size,
94            reader_fd,
95        })
96    }
97
98    pub fn pid(&self) -> Option<u32> {
99        self.child.process_id()
100    }
101
102    pub fn is_running(&mut self) -> bool {
103        self.child
104            .try_wait()
105            .map(|status| status.is_none())
106            .unwrap_or(false)
107    }
108
109    pub fn write(&self, data: &[u8]) -> Result<(), PtyError> {
110        let mut writer = mutex_lock_or_recover(&self.writer);
111        writer
112            .write_all(data)
113            .map_err(|e| PtyError::Write(e.to_string()))?;
114        writer.flush().map_err(|e| PtyError::Write(e.to_string()))?;
115        Ok(())
116    }
117
118    pub fn write_str(&self, s: &str) -> Result<(), PtyError> {
119        self.write(s.as_bytes())
120    }
121
122    pub fn try_read(&self, buf: &mut [u8], timeout_ms: i32) -> Result<usize, PtyError> {
123        let mut pollfd = libc::pollfd {
124            fd: self.reader_fd,
125            events: libc::POLLIN,
126            revents: 0,
127        };
128
129        let result = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
130
131        if result < 0 {
132            return Err(PtyError::Read("poll failed".to_string()));
133        }
134
135        if result == 0 {
136            return Ok(0);
137        }
138
139        let mut reader = mutex_lock_or_recover(&self.reader);
140        reader.read(buf).map_err(|e| PtyError::Read(e.to_string()))
141    }
142
143    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<(), PtyError> {
144        self.size = PtySize {
145            rows,
146            cols,
147            pixel_width: 0,
148            pixel_height: 0,
149        };
150        self.master
151            .resize(self.size)
152            .map_err(|e| PtyError::Resize(e.to_string()))
153    }
154
155    pub fn kill(&mut self) -> Result<(), PtyError> {
156        // If the process is already stopped, the desired state is achieved
157        if !self.is_running() {
158            return Ok(());
159        }
160
161        self.child
162            .kill()
163            .map_err(|e| PtyError::Spawn(e.to_string()))
164    }
165}
166
/// Translates a human-readable key name into the bytes a terminal sends for that key.
///
/// Accepted forms:
/// - a named key such as `"Enter"`, `"ArrowUp"`/`"Up"`, `"PageDown"` or `"F5"`
///   (names are case-sensitive);
/// - any single character, including non-ASCII ones and `"+"`, sent as its UTF-8 bytes;
/// - `"<Modifier>+<key>"`, where the modifier (case-insensitive) is `Ctrl`/`Control`,
///   `Alt`/`Meta` or `Shift`. Only the first `+` separates modifier from key, so
///   `"Alt+Ctrl+C"` is Alt applied to `"Ctrl+C"`.
///
/// Returns `None` for unrecognized keys or modifiers.
pub fn key_to_escape_sequence(key: &str) -> Option<Vec<u8>> {
    // A combo needs both a modifier and a key; anything else (e.g. a lone "+") is a
    // literal key name or character.
    if let Some((modifier, base_key)) = key.split_once('+') {
        if !modifier.is_empty() && !base_key.is_empty() {
            return modified_key_sequence(modifier, base_key);
        }
    }
    named_key_sequence(key)
}

/// Encodes `base_key` with a single `modifier` applied.
fn modified_key_sequence(modifier: &str, base_key: &str) -> Option<Vec<u8>> {
    match modifier.to_lowercase().as_str() {
        "ctrl" | "control" => ctrl_sequence(base_key),
        "alt" | "meta" => {
            // Alt/Meta is sent as an ESC prefix followed by the unmodified key.
            let mut seq = vec![0x1b];
            seq.extend(key_to_escape_sequence(base_key)?);
            Some(seq)
        }
        "shift" => {
            if base_key.eq_ignore_ascii_case("tab") {
                // Back-tab (CSI Z).
                Some(b"\x1b[Z".to_vec())
            } else if is_single_char(base_key) {
                Some(base_key.to_uppercase().into_bytes())
            } else {
                None
            }
        }
        _ => None,
    }
}

/// Maps `Ctrl+<base_key>` to its C0 control byte (e.g. `C` -> 0x03, `[` -> ESC).
fn ctrl_sequence(base_key: &str) -> Option<Vec<u8>> {
    if is_single_char(base_key) {
        let c = base_key.chars().next()?.to_ascii_uppercase();
        if c.is_ascii_alphabetic() {
            // ASCII letters map to 0x01..=0x1A regardless of case; the cast cannot
            // truncate because `c` is ASCII.
            return Some(vec![c as u8 - b'A' + 1]);
        }
    }

    // Remaining control codes: the character's ASCII value XOR 0x40, plus the
    // conventional Ctrl+Space -> NUL.
    let byte = match base_key.to_lowercase().as_str() {
        "@" | "space" => 0x00,
        "[" => 0x1b,
        "\\" => 0x1c,
        "]" => 0x1d,
        "^" => 0x1e,
        "_" => 0x1f,
        "?" => 0x7f,
        _ => return None,
    };
    Some(vec![byte])
}

/// Looks up an unmodified key name, falling back to a single literal character.
fn named_key_sequence(key: &str) -> Option<Vec<u8>> {
    let seq: &[u8] = match key {
        "Enter" | "Return" => b"\r",
        "Tab" => b"\t",
        "Escape" | "Esc" => b"\x1b",
        "Backspace" => b"\x7f",
        "Delete" => b"\x1b[3~",
        "Space" => b" ",

        "ArrowUp" | "Up" => b"\x1b[A",
        "ArrowDown" | "Down" => b"\x1b[B",
        "ArrowRight" | "Right" => b"\x1b[C",
        "ArrowLeft" | "Left" => b"\x1b[D",

        "Home" => b"\x1b[H",
        "End" => b"\x1b[F",
        "PageUp" => b"\x1b[5~",
        "PageDown" => b"\x1b[6~",
        "Insert" => b"\x1b[2~",

        "F1" => b"\x1bOP",
        "F2" => b"\x1bOQ",
        "F3" => b"\x1bOR",
        "F4" => b"\x1bOS",
        "F5" => b"\x1b[15~",
        "F6" => b"\x1b[17~",
        "F7" => b"\x1b[18~",
        "F8" => b"\x1b[19~",
        "F9" => b"\x1b[20~",
        "F10" => b"\x1b[21~",
        "F11" => b"\x1b[23~",
        "F12" => b"\x1b[24~",

        // Any other single character (including non-ASCII) is sent as its UTF-8 bytes;
        // checking chars rather than bytes is what lets e.g. "é" through.
        _ if is_single_char(key) => key.as_bytes(),
        _ => return None,
    };
    Some(seq.to_vec())
}

/// Returns true when `s` consists of exactly one Unicode scalar value.
fn is_single_char(s: &str) -> bool {
    let mut chars = s.chars();
    chars.next().is_some() && chars.next().is_none()
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of a few representative key names and their byte encodings.
    #[test]
    fn test_key_to_escape_sequence() {
        let cases: [(&str, &[u8]); 6] = [
            ("Enter", b"\r"),
            ("Tab", b"\t"),
            ("Escape", b"\x1b"),
            ("ArrowUp", b"\x1b[A"),
            ("Ctrl+C", b"\x03"),
            ("a", b"a"),
        ];

        for (key, expected) in cases {
            assert_eq!(
                key_to_escape_sequence(key).as_deref(),
                Some(expected),
                "unexpected bytes for key {key:?}"
            );
        }
    }
}