wezterm_ssh/
pty.rs

1use crate::session::{SessionRequest, SessionSender, SignalChannel};
2use crate::sessioninner::{ChannelId, ChannelInfo, DescriptorState};
3use crate::sessionwrap::SessionWrap;
4use filedescriptor::{socketpair, FileDescriptor};
5use portable_pty::{ExitStatus, PtySize};
6use smol::channel::{bounded, Receiver, TryRecvError};
7use std::collections::{HashMap, VecDeque};
8use std::io::{Read, Write};
9use std::sync::Mutex;
10
11#[derive(Debug)]
12pub(crate) struct NewPty {
13    pub term: String,
14    pub size: PtySize,
15    pub command_line: Option<String>,
16    pub env: Option<HashMap<String, String>>,
17}
18
19#[derive(Debug)]
20pub(crate) struct ResizePty {
21    pub channel: ChannelId,
22    pub size: PtySize,
23}
24
25#[derive(Debug)]
26pub struct SshPty {
27    pub(crate) channel: ChannelId,
28    pub(crate) tx: Option<SessionSender>,
29    pub(crate) reader: FileDescriptor,
30    pub(crate) writer: FileDescriptor,
31    pub(crate) size: Mutex<PtySize>,
32}
33
34impl std::io::Write for SshPty {
35    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
36        self.writer.write(buf)
37    }
38
39    fn flush(&mut self) -> std::io::Result<()> {
40        self.writer.flush()
41    }
42}
43
44impl portable_pty::MasterPty for SshPty {
45    fn resize(&self, size: PtySize) -> anyhow::Result<()> {
46        self.tx
47            .as_ref()
48            .unwrap()
49            .try_send(SessionRequest::ResizePty(
50                ResizePty {
51                    channel: self.channel,
52                    size,
53                },
54                None,
55            ))?;
56
57        *self.size.lock().unwrap() = size;
58        Ok(())
59    }
60
61    fn get_size(&self) -> anyhow::Result<PtySize> {
62        Ok(*self.size.lock().unwrap())
63    }
64
65    fn try_clone_reader(&self) -> anyhow::Result<Box<(dyn Read + Send + 'static)>> {
66        let reader = self.reader.try_clone()?;
67        Ok(Box::new(reader))
68    }
69
70    fn try_clone_writer(&self) -> anyhow::Result<Box<(dyn Write + Send + 'static)>> {
71        let writer = self.writer.try_clone()?;
72        Ok(Box::new(writer))
73    }
74
75    #[cfg(unix)]
76    fn process_group_leader(&self) -> Option<i32> {
77        // It's not local, so there's no meaningful leader
78        None
79    }
80}
81
82#[derive(Debug)]
83pub struct SshChildProcess {
84    pub(crate) channel: ChannelId,
85    pub(crate) tx: Option<SessionSender>,
86    pub(crate) exit: Receiver<ExitStatus>,
87    pub(crate) exited: Option<ExitStatus>,
88}
89
90impl SshChildProcess {
91    pub async fn async_wait(&mut self) -> std::io::Result<ExitStatus> {
92        if let Some(status) = self.exited.as_ref() {
93            return Ok(status.clone());
94        }
95        match self.exit.recv().await {
96            Ok(status) => {
97                self.exited.replace(status.clone());
98                Ok(status)
99            }
100            Err(_) => {
101                let status = ExitStatus::with_exit_code(1);
102                self.exited.replace(status.clone());
103                Ok(status)
104            }
105        }
106    }
107}
108
109impl portable_pty::Child for SshChildProcess {
110    fn try_wait(&mut self) -> std::io::Result<Option<ExitStatus>> {
111        if let Some(status) = self.exited.as_ref() {
112            return Ok(Some(status.clone()));
113        }
114        match self.exit.try_recv() {
115            Ok(status) => {
116                self.exited.replace(status.clone());
117                Ok(Some(status))
118            }
119            Err(TryRecvError::Empty) => Ok(None),
120            Err(TryRecvError::Closed) => {
121                let status = ExitStatus::with_exit_code(1);
122                self.exited.replace(status.clone());
123                Ok(Some(status))
124            }
125        }
126    }
127
128    fn wait(&mut self) -> std::io::Result<portable_pty::ExitStatus> {
129        if let Some(status) = self.exited.as_ref() {
130            return Ok(status.clone());
131        }
132        match smol::block_on(self.exit.recv()) {
133            Ok(status) => {
134                self.exited.replace(status.clone());
135                Ok(status)
136            }
137            Err(_) => {
138                let status = ExitStatus::with_exit_code(1);
139                self.exited.replace(status.clone());
140                Ok(status)
141            }
142        }
143    }
144
145    fn process_id(&self) -> Option<u32> {
146        None
147    }
148
149    #[cfg(windows)]
150    fn as_raw_handle(&self) -> Option<std::os::windows::io::RawHandle> {
151        None
152    }
153}
154
155impl portable_pty::ChildKiller for SshChildProcess {
156    fn kill(&mut self) -> std::io::Result<()> {
157        if let Some(tx) = self.tx.as_ref() {
158            tx.try_send(SessionRequest::SignalChannel(SignalChannel {
159                channel: self.channel,
160                signame: "HUP",
161            }))
162            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
163        }
164        Ok(())
165    }
166
167    fn clone_killer(&self) -> Box<dyn portable_pty::ChildKiller + Send + Sync> {
168        Box::new(SshChildKiller {
169            tx: self.tx.clone(),
170            channel: self.channel,
171        })
172    }
173}
174
175#[derive(Debug, Clone)]
176struct SshChildKiller {
177    pub(crate) tx: Option<SessionSender>,
178    pub(crate) channel: ChannelId,
179}
180
181impl portable_pty::ChildKiller for SshChildKiller {
182    fn kill(&mut self) -> std::io::Result<()> {
183        if let Some(tx) = self.tx.as_ref() {
184            tx.try_send(SessionRequest::SignalChannel(SignalChannel {
185                channel: self.channel,
186                signame: "HUP",
187            }))
188            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
189        }
190        Ok(())
191    }
192
193    fn clone_killer(&self) -> Box<dyn portable_pty::ChildKiller + Send + Sync> {
194        Box::new(SshChildKiller {
195            tx: self.tx.clone(),
196            channel: self.channel,
197        })
198    }
199}
200
201impl crate::sessioninner::SessionInner {
202    pub fn new_pty(
203        &mut self,
204        sess: &mut SessionWrap,
205        newpty: NewPty,
206    ) -> anyhow::Result<(SshPty, SshChildProcess)> {
207        sess.set_blocking(true);
208
209        let mut channel = sess.open_session()?;
210
211        /* libssh2 doesn't properly support agent forwarding
212         * at this time:
213         * <https://github.com/libssh2/libssh2/issues/535>
214        if let Some("yes") = self.config.get("forwardagent").map(|s| s.as_str()) {
215            log::info!("requesting agent forwarding");
216            if let Err(err) = channel.request_auth_agent_forwarding() {
217                log::error!("Failed to establish agent forwarding: {:#}", err);
218            }
219            log::info!("agent forwarding OK!");
220        }
221        */
222
223        channel.request_pty(&newpty)?;
224
225        if let Some(env) = &newpty.env {
226            for (key, val) in env {
227                if let Err(err) = channel.request_env(key, val) {
228                    // Depending on the server configuration, a given
229                    // setenv request may not succeed, but that doesn't
230                    // prevent the connection from being set up.
231                    log::warn!(
232                        "ssh: setenv {}={} failed: {}. \
233                         Check the AcceptEnv setting on the ssh server side.",
234                        key,
235                        val,
236                        err
237                    );
238                }
239            }
240        }
241
242        if let Some(cmd) = &newpty.command_line {
243            channel.request_exec(cmd)?;
244        } else {
245            channel.request_shell()?;
246        }
247
248        let channel_id = self.next_channel_id;
249        self.next_channel_id += 1;
250
251        let (write_to_stdin, mut read_from_stdin) = socketpair()?;
252        let (mut write_to_stdout, read_from_stdout) = socketpair()?;
253        let write_to_stderr = write_to_stdout.try_clone()?;
254
255        read_from_stdin.set_non_blocking(true)?;
256        write_to_stdout.set_non_blocking(true)?;
257
258        let ssh_pty = SshPty {
259            channel: channel_id,
260            tx: None,
261            reader: read_from_stdout,
262            writer: write_to_stdin,
263            size: Mutex::new(newpty.size),
264        };
265
266        let (exit_tx, exit_rx) = bounded(1);
267
268        let child = SshChildProcess {
269            channel: channel_id,
270            tx: None,
271            exit: exit_rx,
272            exited: None,
273        };
274
275        let info = ChannelInfo {
276            channel_id,
277            channel,
278            exit: Some(exit_tx),
279            descriptors: [
280                DescriptorState {
281                    fd: Some(read_from_stdin),
282                    buf: VecDeque::with_capacity(8192),
283                },
284                DescriptorState {
285                    fd: Some(write_to_stdout),
286                    buf: VecDeque::with_capacity(8192),
287                },
288                DescriptorState {
289                    fd: Some(write_to_stderr),
290                    buf: VecDeque::with_capacity(8192),
291                },
292            ],
293        };
294
295        self.channels.insert(channel_id, info);
296
297        Ok((ssh_pty, child))
298    }
299
300    pub fn resize_pty(&mut self, resize: ResizePty) -> anyhow::Result<()> {
301        let info = self
302            .channels
303            .get_mut(&resize.channel)
304            .ok_or_else(|| anyhow::anyhow!("invalid channel id {}", resize.channel))?;
305        info.channel.resize_pty(&resize)?;
306        Ok(())
307    }
308}