1use crate::sync_utils::mutex_lock_or_recover;
2use portable_pty::{native_pty_system, Child, CommandBuilder, MasterPty, PtySize};
3use std::io::{Read, Write};
4use std::os::fd::RawFd;
5use std::sync::{Arc, Mutex};
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum PtyError {
10 #[error("Failed to open PTY: {0}")]
11 Open(String),
12 #[error("Failed to spawn process: {0}")]
13 Spawn(String),
14 #[error("Failed to write to PTY: {0}")]
15 Write(String),
16 #[error("Failed to read from PTY: {0}")]
17 Read(String),
18 #[error("Failed to resize PTY: {0}")]
19 Resize(String),
20}
21
/// A child process attached to the slave side of a pseudo-terminal, together
/// with the master-side handles used to talk to it.
///
/// NOTE: Rust drops struct fields in declaration order; keep this order unless
/// the resulting drop sequence of the PTY handles has been checked.
pub struct PtyHandle {
    // Master side of the PTY; used for resizing and to obtain `reader_fd`.
    master: Box<dyn MasterPty + Send>,
    // Handle to the spawned process (pid, status polling, kill).
    child: Box<dyn Child + Send + Sync>,
    // Reader cloned from the master; locked for each read in `try_read`.
    reader: Arc<Mutex<Box<dyn Read + Send>>>,
    // Writer taken from the master; locked for each write.
    writer: Arc<Mutex<Box<dyn Write + Send>>>,
    // Last size passed to `spawn`/`resize`.
    size: PtySize,
    // Raw fd of the master side. `try_read` polls it for readability, while the
    // actual bytes are read through `reader`.
    reader_fd: RawFd,
}
30
31impl PtyHandle {
32 pub fn spawn(
33 command: &str,
34 args: &[String],
35 cwd: Option<&str>,
36 env: Option<&std::collections::HashMap<String, String>>,
37 cols: u16,
38 rows: u16,
39 ) -> Result<Self, PtyError> {
40 let pty_system = native_pty_system();
41
42 let size = PtySize {
43 rows,
44 cols,
45 pixel_width: 0,
46 pixel_height: 0,
47 };
48
49 let pair = pty_system
50 .openpty(size)
51 .map_err(|e| PtyError::Open(e.to_string()))?;
52
53 let mut cmd = CommandBuilder::new(command);
54 cmd.args(args);
55
56 if let Some(dir) = cwd {
57 cmd.cwd(dir);
58 }
59
60 if let Some(env_vars) = env {
61 for (key, value) in env_vars {
62 cmd.env(key, value);
63 }
64 }
65
66 cmd.env("TERM", "xterm-256color");
67
68 let child = pair
69 .slave
70 .spawn_command(cmd)
71 .map_err(|e| PtyError::Spawn(e.to_string()))?;
72
73 let reader = pair
74 .master
75 .try_clone_reader()
76 .map_err(|e| PtyError::Open(e.to_string()))?;
77
78 let reader_fd = pair
79 .master
80 .as_raw_fd()
81 .ok_or_else(|| PtyError::Open("Failed to get master fd".to_string()))?;
82
83 let writer = pair
84 .master
85 .take_writer()
86 .map_err(|e| PtyError::Open(e.to_string()))?;
87
88 Ok(Self {
89 master: pair.master,
90 child,
91 reader: Arc::new(Mutex::new(reader)),
92 writer: Arc::new(Mutex::new(writer)),
93 size,
94 reader_fd,
95 })
96 }
97
98 pub fn pid(&self) -> Option<u32> {
99 self.child.process_id()
100 }
101
102 pub fn is_running(&mut self) -> bool {
103 self.child
104 .try_wait()
105 .map(|status| status.is_none())
106 .unwrap_or(false)
107 }
108
109 pub fn write(&self, data: &[u8]) -> Result<(), PtyError> {
110 let mut writer = mutex_lock_or_recover(&self.writer);
111 writer
112 .write_all(data)
113 .map_err(|e| PtyError::Write(e.to_string()))?;
114 writer.flush().map_err(|e| PtyError::Write(e.to_string()))?;
115 Ok(())
116 }
117
118 pub fn write_str(&self, s: &str) -> Result<(), PtyError> {
119 self.write(s.as_bytes())
120 }
121
122 pub fn try_read(&self, buf: &mut [u8], timeout_ms: i32) -> Result<usize, PtyError> {
123 let mut pollfd = libc::pollfd {
124 fd: self.reader_fd,
125 events: libc::POLLIN,
126 revents: 0,
127 };
128
129 let result = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
130
131 if result < 0 {
132 return Err(PtyError::Read("poll failed".to_string()));
133 }
134
135 if result == 0 {
136 return Ok(0);
137 }
138
139 let mut reader = mutex_lock_or_recover(&self.reader);
140 reader.read(buf).map_err(|e| PtyError::Read(e.to_string()))
141 }
142
143 pub fn resize(&mut self, cols: u16, rows: u16) -> Result<(), PtyError> {
144 self.size = PtySize {
145 rows,
146 cols,
147 pixel_width: 0,
148 pixel_height: 0,
149 };
150 self.master
151 .resize(self.size)
152 .map_err(|e| PtyError::Resize(e.to_string()))
153 }
154
155 pub fn kill(&mut self) -> Result<(), PtyError> {
156 if !self.is_running() {
158 return Ok(());
159 }
160
161 self.child
162 .kill()
163 .map_err(|e| PtyError::Spawn(e.to_string()))
164 }
165}
166
/// Translates a human-readable key name into the bytes a terminal sends for it
/// (xterm-style sequences).
///
/// Accepted forms:
/// - a single character (`"a"`, `"+"`, `"é"`), sent as its UTF-8 bytes;
/// - a named key (`"Enter"`, `"ArrowUp"`, `"F5"`, ...), matched case-sensitively;
/// - `"<Modifier>+<key>"`, where the modifier is `Ctrl`/`Control`, `Alt`/`Meta`
///   or `Shift` (case-insensitive).
///
/// Only the first `+` separates the modifier. This lets `Alt` wrap another
/// modified key (`"Alt+Ctrl+c"` -> ESC, 0x03) and lets `"Ctrl++"` name the plus
/// key. A lone `"+"` is the plus character itself.
///
/// Returns `None` for unknown key names or unsupported modifier combinations.
pub fn key_to_escape_sequence(key: &str) -> Option<Vec<u8>> {
    if let Some((modifier, base_key)) = key.split_once('+') {
        // An empty side means '+' is not acting as a modifier separator
        // (e.g. the key "+" itself).
        if !modifier.is_empty() && !base_key.is_empty() {
            return modified_key_sequence(modifier, base_key);
        }
    }

    let seq: &[u8] = match key {
        "Enter" | "Return" => b"\r",
        "Tab" => b"\t",
        "Escape" | "Esc" => b"\x1b",
        "Backspace" => b"\x7f",
        "Delete" => b"\x1b[3~",
        "Space" => b" ",

        "ArrowUp" | "Up" => b"\x1b[A",
        "ArrowDown" | "Down" => b"\x1b[B",
        "ArrowRight" | "Right" => b"\x1b[C",
        "ArrowLeft" | "Left" => b"\x1b[D",

        "Home" => b"\x1b[H",
        "End" => b"\x1b[F",
        "PageUp" => b"\x1b[5~",
        "PageDown" => b"\x1b[6~",
        "Insert" => b"\x1b[2~",

        "F1" => b"\x1bOP",
        "F2" => b"\x1bOQ",
        "F3" => b"\x1bOR",
        "F4" => b"\x1bOS",
        "F5" => b"\x1b[15~",
        "F6" => b"\x1b[17~",
        "F7" => b"\x1b[18~",
        "F8" => b"\x1b[19~",
        "F9" => b"\x1b[20~",
        "F10" => b"\x1b[21~",
        "F11" => b"\x1b[23~",
        "F12" => b"\x1b[24~",

        // Any single character (including multi-byte UTF-8) is sent verbatim.
        _ if single_char(key).is_some() => key.as_bytes(),

        _ => return None,
    };
    Some(seq.to_vec())
}

/// Encodes `base_key` pressed together with `modifier` (both non-empty).
fn modified_key_sequence(modifier: &str, base_key: &str) -> Option<Vec<u8>> {
    match modifier.to_ascii_lowercase().as_str() {
        "ctrl" | "control" => ctrl_byte(base_key).map(|byte| vec![byte]),
        "alt" | "meta" => {
            // Alt/Meta is sent as an ESC prefix in front of the base key.
            let mut seq = vec![0x1b];
            seq.extend(key_to_escape_sequence(base_key)?);
            Some(seq)
        }
        "shift" => {
            if base_key.eq_ignore_ascii_case("tab") {
                // Back-tab (CBT).
                Some(vec![0x1b, b'[', b'Z'])
            } else if single_char(base_key).is_some() {
                Some(base_key.to_uppercase().into_bytes())
            } else {
                None
            }
        }
        _ => None,
    }
}

/// Maps `base_key` to the C0 control byte produced by holding Ctrl, following
/// caret notation (^A = 0x01 .. ^Z = 0x1A, ^[ = ESC, ^? = DEL, ^@ = NUL).
fn ctrl_byte(base_key: &str) -> Option<u8> {
    if base_key.eq_ignore_ascii_case("space") {
        return Some(0x00);
    }
    match single_char(base_key)?.to_ascii_uppercase() {
        letter @ 'A'..='Z' => Some(letter as u8 - b'A' + 1),
        '@' => Some(0x00),
        '[' => Some(0x1b),
        '\\' => Some(0x1c),
        ']' => Some(0x1d),
        '^' => Some(0x1e),
        '_' => Some(0x1f),
        '?' => Some(0x7f),
        _ => None,
    }
}

/// Returns the only character of `s`, or `None` if `s` is empty or longer.
fn single_char(s: &str) -> Option<char> {
    let mut chars = s.chars();
    match (chars.next(), chars.next()) {
        (Some(c), None) => Some(c),
        _ => None,
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of representative named, modified and literal keys.
    #[test]
    fn test_key_to_escape_sequence() {
        let cases: [(&str, &[u8]); 6] = [
            ("Enter", b"\r"),
            ("Tab", b"\t"),
            ("Escape", b"\x1b"),
            ("ArrowUp", b"\x1b[A"),
            ("Ctrl+C", b"\x03"),
            ("a", b"a"),
        ];

        for (key, expected) in cases {
            assert_eq!(
                key_to_escape_sequence(key),
                Some(expected.to_vec()),
                "unexpected sequence for key {key:?}"
            );
        }
    }
}