pty_process/
pty.rs

1#![allow(clippy::module_name_repetitions)]
2
3use std::io::Write as _;
4
5type AsyncPty = tokio::io::unix::AsyncFd<crate::sys::Pty>;
6
7/// Allocate and return a new pty and pts.
8///
9/// # Errors
10/// Returns an error if the pty failed to be allocated, or if we were
11/// unable to put it into non-blocking mode.
12pub fn open() -> crate::Result<(Pty, Pts)> {
13    let pty = crate::sys::Pty::open()?;
14    let pts = pty.pts()?;
15    pty.set_nonblocking()?;
16    let pty = tokio::io::unix::AsyncFd::new(pty)?;
17    Ok((Pty(pty), Pts(pts)))
18}
19
20/// An allocated pty
21pub struct Pty(AsyncPty);
22
23impl Pty {
24    /// Use the provided file descriptor as a pty.
25    ///
26    /// # Safety
27    /// The provided file descriptor must be valid, open, belong to a pty,
28    /// and put into nonblocking mode.
29    ///
30    /// # Errors
31    /// Returns an error if it fails to be registered with the async runtime.
32    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> crate::Result<Self> {
33        Ok(Self(tokio::io::unix::AsyncFd::new(unsafe {
34            crate::sys::Pty::from_fd(fd)
35        })?))
36    }
37
38    /// Change the terminal size associated with the pty.
39    ///
40    /// # Errors
41    /// Returns an error if we were unable to set the terminal size.
42    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
43        self.0.get_ref().set_term_size(size)
44    }
45
46    /// Splits a `Pty` into a read half and a write half, which can be used to
47    /// read from and write to the pty concurrently. Does not allocate, but
48    /// the returned halves cannot be moved to independent tasks.
49    pub fn split(&mut self) -> (ReadPty<'_>, WritePty<'_>) {
50        (ReadPty(&self.0), WritePty(&self.0))
51    }
52
53    /// Splits a `Pty` into a read half and a write half, which can be used to
54    /// read from and write to the pty concurrently. This method requires an
55    /// allocation, but the returned halves can be moved to independent tasks.
56    /// The original `Pty` instance can be recovered via the
57    /// [`OwnedReadPty::unsplit`] method.
58    #[must_use]
59    pub fn into_split(self) -> (OwnedReadPty, OwnedWritePty) {
60        let Self(pt) = self;
61        let read_pt = std::sync::Arc::new(pt);
62        let write_pt = std::sync::Arc::clone(&read_pt);
63        (OwnedReadPty(read_pt), OwnedWritePty(write_pt))
64    }
65}
66
67impl From<Pty> for std::os::fd::OwnedFd {
68    fn from(pty: Pty) -> Self {
69        pty.0.into_inner().into()
70    }
71}
72
73impl std::os::fd::AsFd for Pty {
74    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
75        self.0.get_ref().as_fd()
76    }
77}
78
79impl std::os::fd::AsRawFd for Pty {
80    fn as_raw_fd(&self) -> std::os::fd::RawFd {
81        self.0.get_ref().as_raw_fd()
82    }
83}
84
85impl tokio::io::AsyncRead for Pty {
86    fn poll_read(
87        self: std::pin::Pin<&mut Self>,
88        cx: &mut std::task::Context<'_>,
89        buf: &mut tokio::io::ReadBuf,
90    ) -> std::task::Poll<std::io::Result<()>> {
91        poll_read(&self.0, cx, buf)
92    }
93}
94
95impl tokio::io::AsyncWrite for Pty {
96    fn poll_write(
97        self: std::pin::Pin<&mut Self>,
98        cx: &mut std::task::Context<'_>,
99        buf: &[u8],
100    ) -> std::task::Poll<std::io::Result<usize>> {
101        poll_write(&self.0, cx, buf)
102    }
103
104    fn poll_flush(
105        self: std::pin::Pin<&mut Self>,
106        cx: &mut std::task::Context<'_>,
107    ) -> std::task::Poll<std::io::Result<()>> {
108        poll_flush(&self.0, cx)
109    }
110
111    fn poll_shutdown(
112        self: std::pin::Pin<&mut Self>,
113        _cx: &mut std::task::Context<'_>,
114    ) -> std::task::Poll<Result<(), std::io::Error>> {
115        std::task::Poll::Ready(Ok(()))
116    }
117}
118
119/// The child end of the pty
120///
121/// See [`open`] and [`Command::spawn`](crate::Command::spawn)
122pub struct Pts(pub(crate) crate::sys::Pts);
123
124impl Pts {
125    /// Use the provided file descriptor as a pts.
126    ///
127    /// # Safety
128    /// The provided file descriptor must be valid, open, and belong to the
129    /// child end of a pty.
130    #[must_use]
131    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> Self {
132        Self(unsafe { crate::sys::Pts::from_fd(fd) })
133    }
134}
135
136impl std::os::fd::AsFd for Pts {
137    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
138        self.0.as_fd()
139    }
140}
141
142impl std::os::fd::AsRawFd for Pts {
143    fn as_raw_fd(&self) -> std::os::fd::RawFd {
144        self.0.as_raw_fd()
145    }
146}
147
148/// Borrowed read half of a [`Pty`]
149pub struct ReadPty<'a>(&'a AsyncPty);
150
151impl tokio::io::AsyncRead for ReadPty<'_> {
152    fn poll_read(
153        self: std::pin::Pin<&mut Self>,
154        cx: &mut std::task::Context<'_>,
155        buf: &mut tokio::io::ReadBuf,
156    ) -> std::task::Poll<std::io::Result<()>> {
157        poll_read(self.0, cx, buf)
158    }
159}
160
161/// Borrowed write half of a [`Pty`]
162pub struct WritePty<'a>(&'a AsyncPty);
163
164impl WritePty<'_> {
165    /// Change the terminal size associated with the pty.
166    ///
167    /// # Errors
168    /// Returns an error if we were unable to set the terminal size.
169    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
170        self.0.get_ref().set_term_size(size)
171    }
172}
173
174impl tokio::io::AsyncWrite for WritePty<'_> {
175    fn poll_write(
176        self: std::pin::Pin<&mut Self>,
177        cx: &mut std::task::Context<'_>,
178        buf: &[u8],
179    ) -> std::task::Poll<std::io::Result<usize>> {
180        poll_write(self.0, cx, buf)
181    }
182
183    fn poll_flush(
184        self: std::pin::Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> std::task::Poll<std::io::Result<()>> {
187        poll_flush(self.0, cx)
188    }
189
190    fn poll_shutdown(
191        self: std::pin::Pin<&mut Self>,
192        _cx: &mut std::task::Context<'_>,
193    ) -> std::task::Poll<Result<(), std::io::Error>> {
194        std::task::Poll::Ready(Ok(()))
195    }
196}
197
198/// Owned read half of a [`Pty`]
199#[derive(Debug)]
200pub struct OwnedReadPty(std::sync::Arc<AsyncPty>);
201
202impl OwnedReadPty {
203    /// Attempt to join the two halves of a `Pty` back into a single instance.
204    /// The two halves must have originated from calling
205    /// [`into_split`](Pty::into_split) on a single instance.
206    ///
207    /// # Errors
208    /// Returns an error if the two halves came from different [`Pty`]
209    /// instances. The mismatched halves are returned as part of the error.
210    pub fn unsplit(self, write_half: OwnedWritePty) -> crate::Result<Pty> {
211        let Self(read_pt) = self;
212        let OwnedWritePty(write_pt) = write_half;
213        if std::sync::Arc::ptr_eq(&read_pt, &write_pt) {
214            drop(write_pt);
215            Ok(Pty(std::sync::Arc::try_unwrap(read_pt)
216                // it shouldn't be possible for more than two references to
217                // the same pty to exist
218                .unwrap_or_else(|_| unreachable!())))
219        } else {
220            Err(crate::Error::Unsplit(
221                Self(read_pt),
222                OwnedWritePty(write_pt),
223            ))
224        }
225    }
226}
227
228impl tokio::io::AsyncRead for OwnedReadPty {
229    fn poll_read(
230        self: std::pin::Pin<&mut Self>,
231        cx: &mut std::task::Context<'_>,
232        buf: &mut tokio::io::ReadBuf,
233    ) -> std::task::Poll<std::io::Result<()>> {
234        poll_read(&self.0, cx, buf)
235    }
236}
237
238/// Owned write half of a [`Pty`]
239#[derive(Debug)]
240pub struct OwnedWritePty(std::sync::Arc<AsyncPty>);
241
242impl OwnedWritePty {
243    /// Change the terminal size associated with the pty.
244    ///
245    /// # Errors
246    /// Returns an error if we were unable to set the terminal size.
247    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
248        self.0.get_ref().set_term_size(size)
249    }
250}
251
252impl tokio::io::AsyncWrite for OwnedWritePty {
253    fn poll_write(
254        self: std::pin::Pin<&mut Self>,
255        cx: &mut std::task::Context<'_>,
256        buf: &[u8],
257    ) -> std::task::Poll<std::io::Result<usize>> {
258        poll_write(&self.0, cx, buf)
259    }
260
261    fn poll_flush(
262        self: std::pin::Pin<&mut Self>,
263        cx: &mut std::task::Context<'_>,
264    ) -> std::task::Poll<std::io::Result<()>> {
265        poll_flush(&self.0, cx)
266    }
267
268    fn poll_shutdown(
269        self: std::pin::Pin<&mut Self>,
270        _cx: &mut std::task::Context<'_>,
271    ) -> std::task::Poll<Result<(), std::io::Error>> {
272        std::task::Poll::Ready(Ok(()))
273    }
274}
275
276fn poll_read(
277    pty: &AsyncPty,
278    cx: &mut std::task::Context<'_>,
279    buf: &mut tokio::io::ReadBuf,
280) -> std::task::Poll<std::io::Result<()>> {
281    loop {
282        let mut guard = match pty.poll_read_ready(cx) {
283            std::task::Poll::Ready(guard) => guard,
284            std::task::Poll::Pending => return std::task::Poll::Pending,
285        }?;
286        let prev_filled = buf.filled().len();
287        // SAFETY: we only pass b to read_buf, which never uninitializes any
288        // part of the buffer it is given
289        let b = unsafe { buf.unfilled_mut() };
290        match guard.try_io(|inner| inner.get_ref().read_buf(b)) {
291            Ok(Ok((filled, _unfilled))) => {
292                let bytes = filled.len();
293                // SAFETY: read_buf is given a buffer that starts at the end
294                // of the filled section, and then both initializes and fills
295                // some amount of the buffer after that (and never
296                // deinitializes anything). we know that at least this many
297                // bytes have been initialized (they either were filled and
298                // initialized previously, or the call to read_buf did), and
299                // assume_init will ignore any attempts to shrink the
300                // initialized space, so this call is always safe.
301                unsafe { buf.assume_init(prev_filled + bytes) };
302                buf.advance(bytes);
303                return std::task::Poll::Ready(Ok(()));
304            }
305            Ok(Err(e)) => return std::task::Poll::Ready(Err(e)),
306            Err(_would_block) => {}
307        }
308    }
309}
310
311fn poll_write(
312    pty: &AsyncPty,
313    cx: &mut std::task::Context<'_>,
314    buf: &[u8],
315) -> std::task::Poll<std::io::Result<usize>> {
316    loop {
317        let mut guard = match pty.poll_write_ready(cx) {
318            std::task::Poll::Ready(guard) => guard,
319            std::task::Poll::Pending => return std::task::Poll::Pending,
320        }?;
321        match guard.try_io(|inner| inner.get_ref().write(buf)) {
322            Ok(result) => return std::task::Poll::Ready(result),
323            Err(_would_block) => {}
324        }
325    }
326}
327
328fn poll_flush(
329    pty: &AsyncPty,
330    cx: &mut std::task::Context<'_>,
331) -> std::task::Poll<std::io::Result<()>> {
332    loop {
333        let mut guard = match pty.poll_write_ready(cx) {
334            std::task::Poll::Ready(guard) => guard,
335            std::task::Poll::Pending => return std::task::Poll::Pending,
336        }?;
337        match guard.try_io(|inner| inner.get_ref().flush()) {
338            Ok(_) => return std::task::Poll::Ready(Ok(())),
339            Err(_would_block) => {}
340        }
341    }
342}