Skip to main content

prek_pty/
pty.rs

1#![allow(clippy::module_name_repetitions)]
2use std::io::Write as _;
3
4type AsyncPty = tokio::io::unix::AsyncFd<crate::sys::Pty>;
5
6/// Allocate and return a new pty and pts.
7///
8/// # Errors
9/// Returns an error if the pty failed to be allocated, or if we were
10/// unable to put it into non-blocking mode.
11pub fn open() -> crate::Result<(Pty, Pts)> {
12    let pty = crate::sys::Pty::open()?;
13    let pts = pty.pts()?;
14    pty.set_nonblocking()?;
15    let pty = tokio::io::unix::AsyncFd::new(pty)?;
16    Ok((Pty(pty), Pts(pts)))
17}
18
19/// An allocated pty
20pub struct Pty(AsyncPty);
21
22impl Pty {
23    /// Use the provided file descriptor as a pty.
24    ///
25    /// # Safety
26    /// The provided file descriptor must be valid, open, belong to a pty,
27    /// and put into nonblocking mode.
28    ///
29    /// # Errors
30    /// Returns an error if it fails to be registered with the async runtime.
31    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> crate::Result<Self> {
32        Ok(Self(tokio::io::unix::AsyncFd::new(unsafe {
33            crate::sys::Pty::from_fd(fd)
34        })?))
35    }
36
37    /// Change the terminal size associated with the pty.
38    ///
39    /// # Errors
40    /// Returns an error if we were unable to set the terminal size.
41    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
42        self.0.get_ref().set_term_size(size)
43    }
44
45    /// Splits a `Pty` into a read half and a write half, which can be used to
46    /// read from and write to the pty concurrently. Does not allocate, but
47    /// the returned halves cannot be moved to independent tasks.
48    pub fn split(&self) -> (ReadPty<'_>, WritePty<'_>) {
49        (ReadPty(&self.0), WritePty(&self.0))
50    }
51
52    /// Splits a `Pty` into a read half and a write half, which can be used to
53    /// read from and write to the pty concurrently. This method requires an
54    /// allocation, but the returned halves can be moved to independent tasks.
55    /// The original `Pty` instance can be recovered via the
56    /// [`OwnedReadPty::unsplit`] method.
57    #[must_use]
58    pub fn into_split(self) -> (OwnedReadPty, OwnedWritePty) {
59        let Self(pt) = self;
60        let read_pt = std::sync::Arc::new(pt);
61        let write_pt = std::sync::Arc::clone(&read_pt);
62        (OwnedReadPty(read_pt), OwnedWritePty(write_pt))
63    }
64}
65
66impl From<Pty> for std::os::fd::OwnedFd {
67    fn from(pty: Pty) -> Self {
68        pty.0.into_inner().into()
69    }
70}
71
72impl std::os::fd::AsFd for Pty {
73    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
74        self.0.get_ref().as_fd()
75    }
76}
77
78impl std::os::fd::AsRawFd for Pty {
79    fn as_raw_fd(&self) -> std::os::fd::RawFd {
80        self.0.get_ref().as_raw_fd()
81    }
82}
83
84impl tokio::io::AsyncRead for Pty {
85    fn poll_read(
86        self: std::pin::Pin<&mut Self>,
87        cx: &mut std::task::Context<'_>,
88        buf: &mut tokio::io::ReadBuf,
89    ) -> std::task::Poll<std::io::Result<()>> {
90        poll_read(&self.0, cx, buf)
91    }
92}
93
94impl tokio::io::AsyncWrite for Pty {
95    fn poll_write(
96        self: std::pin::Pin<&mut Self>,
97        cx: &mut std::task::Context<'_>,
98        buf: &[u8],
99    ) -> std::task::Poll<std::io::Result<usize>> {
100        poll_write(&self.0, cx, buf)
101    }
102
103    fn poll_flush(
104        self: std::pin::Pin<&mut Self>,
105        cx: &mut std::task::Context<'_>,
106    ) -> std::task::Poll<std::io::Result<()>> {
107        poll_flush(&self.0, cx)
108    }
109
110    fn poll_shutdown(
111        self: std::pin::Pin<&mut Self>,
112        _cx: &mut std::task::Context<'_>,
113    ) -> std::task::Poll<Result<(), std::io::Error>> {
114        std::task::Poll::Ready(Ok(()))
115    }
116}
117
118/// The child end of the pty
119///
120/// See [`open`] and [`Command::spawn`](crate::Command::spawn)
121pub struct Pts(pub(crate) crate::sys::Pts);
122
123impl Pts {
124    /// Use the provided file descriptor as a pts.
125    ///
126    /// # Safety
127    /// The provided file descriptor must be valid, open, and belong to the
128    /// child end of a pty.
129    #[must_use]
130    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> Self {
131        Self(unsafe { crate::sys::Pts::from_fd(fd) })
132    }
133
134    pub fn setup_subprocess(
135        &self,
136    ) -> std::io::Result<(
137        std::process::Stdio,
138        std::process::Stdio,
139        std::process::Stdio,
140    )> {
141        self.0.setup_subprocess()
142    }
143
144    pub fn session_leader(&self) -> impl FnMut() -> std::io::Result<()> + use<> {
145        self.0.session_leader()
146    }
147}
148
149impl std::os::fd::AsFd for Pts {
150    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
151        self.0.as_fd()
152    }
153}
154
155impl std::os::fd::AsRawFd for Pts {
156    fn as_raw_fd(&self) -> std::os::fd::RawFd {
157        self.0.as_raw_fd()
158    }
159}
160
161/// Borrowed read half of a [`Pty`]
162pub struct ReadPty<'a>(&'a AsyncPty);
163
164impl tokio::io::AsyncRead for ReadPty<'_> {
165    fn poll_read(
166        self: std::pin::Pin<&mut Self>,
167        cx: &mut std::task::Context<'_>,
168        buf: &mut tokio::io::ReadBuf,
169    ) -> std::task::Poll<std::io::Result<()>> {
170        poll_read(self.0, cx, buf)
171    }
172}
173
174/// Borrowed write half of a [`Pty`]
175pub struct WritePty<'a>(&'a AsyncPty);
176
177impl WritePty<'_> {
178    /// Change the terminal size associated with the pty.
179    ///
180    /// # Errors
181    /// Returns an error if we were unable to set the terminal size.
182    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
183        self.0.get_ref().set_term_size(size)
184    }
185}
186
187impl tokio::io::AsyncWrite for WritePty<'_> {
188    fn poll_write(
189        self: std::pin::Pin<&mut Self>,
190        cx: &mut std::task::Context<'_>,
191        buf: &[u8],
192    ) -> std::task::Poll<std::io::Result<usize>> {
193        poll_write(self.0, cx, buf)
194    }
195
196    fn poll_flush(
197        self: std::pin::Pin<&mut Self>,
198        cx: &mut std::task::Context<'_>,
199    ) -> std::task::Poll<std::io::Result<()>> {
200        poll_flush(self.0, cx)
201    }
202
203    fn poll_shutdown(
204        self: std::pin::Pin<&mut Self>,
205        _cx: &mut std::task::Context<'_>,
206    ) -> std::task::Poll<Result<(), std::io::Error>> {
207        std::task::Poll::Ready(Ok(()))
208    }
209}
210
211/// Owned read half of a [`Pty`]
212#[derive(Debug)]
213pub struct OwnedReadPty(std::sync::Arc<AsyncPty>);
214
215impl OwnedReadPty {
216    /// Attempt to join the two halves of a `Pty` back into a single instance.
217    /// The two halves must have originated from calling
218    /// [`into_split`](Pty::into_split) on a single instance.
219    ///
220    /// # Errors
221    /// Returns an error if the two halves came from different [`Pty`]
222    /// instances. The mismatched halves are returned as part of the error.
223    pub fn unsplit(self, write_half: OwnedWritePty) -> crate::Result<Pty> {
224        let Self(read_pt) = self;
225        let OwnedWritePty(write_pt) = write_half;
226        if std::sync::Arc::ptr_eq(&read_pt, &write_pt) {
227            drop(write_pt);
228            Ok(Pty(std::sync::Arc::try_unwrap(read_pt)
229                // it shouldn't be possible for more than two references to
230                // the same pty to exist
231                .unwrap_or_else(|_| unreachable!())))
232        } else {
233            Err(crate::Error::Unsplit(
234                Self(read_pt),
235                OwnedWritePty(write_pt),
236            ))
237        }
238    }
239}
240
241impl tokio::io::AsyncRead for OwnedReadPty {
242    fn poll_read(
243        self: std::pin::Pin<&mut Self>,
244        cx: &mut std::task::Context<'_>,
245        buf: &mut tokio::io::ReadBuf,
246    ) -> std::task::Poll<std::io::Result<()>> {
247        poll_read(&self.0, cx, buf)
248    }
249}
250
251/// Owned write half of a [`Pty`]
252#[derive(Debug)]
253pub struct OwnedWritePty(std::sync::Arc<AsyncPty>);
254
255impl OwnedWritePty {
256    /// Change the terminal size associated with the pty.
257    ///
258    /// # Errors
259    /// Returns an error if we were unable to set the terminal size.
260    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
261        self.0.get_ref().set_term_size(size)
262    }
263}
264
265impl tokio::io::AsyncWrite for OwnedWritePty {
266    fn poll_write(
267        self: std::pin::Pin<&mut Self>,
268        cx: &mut std::task::Context<'_>,
269        buf: &[u8],
270    ) -> std::task::Poll<std::io::Result<usize>> {
271        poll_write(&self.0, cx, buf)
272    }
273
274    fn poll_flush(
275        self: std::pin::Pin<&mut Self>,
276        cx: &mut std::task::Context<'_>,
277    ) -> std::task::Poll<std::io::Result<()>> {
278        poll_flush(&self.0, cx)
279    }
280
281    fn poll_shutdown(
282        self: std::pin::Pin<&mut Self>,
283        _cx: &mut std::task::Context<'_>,
284    ) -> std::task::Poll<Result<(), std::io::Error>> {
285        std::task::Poll::Ready(Ok(()))
286    }
287}
288
289fn poll_read(
290    pty: &AsyncPty,
291    cx: &mut std::task::Context<'_>,
292    buf: &mut tokio::io::ReadBuf,
293) -> std::task::Poll<std::io::Result<()>> {
294    loop {
295        let mut guard = match pty.poll_read_ready(cx) {
296            std::task::Poll::Ready(guard) => guard,
297            std::task::Poll::Pending => return std::task::Poll::Pending,
298        }?;
299        let prev_filled = buf.filled().len();
300        // SAFETY: we only pass b to read_buf, which never uninitializes any
301        // part of the buffer it is given
302        let b = unsafe { buf.unfilled_mut() };
303        match guard.try_io(|inner| inner.get_ref().read_buf(b)) {
304            Ok(Ok((filled, _unfilled))) => {
305                let bytes = filled.len();
306                // SAFETY: read_buf is given a buffer that starts at the end
307                // of the filled section, and then both initializes and fills
308                // some amount of the buffer after that (and never
309                // deinitializes anything). we know that at least this many
310                // bytes have been initialized (they either were filled and
311                // initialized previously, or the call to read_buf did), and
312                // assume_init will ignore any attempts to shrink the
313                // initialized space, so this call is always safe.
314                unsafe { buf.assume_init(prev_filled + bytes) };
315                buf.advance(bytes);
316                return std::task::Poll::Ready(Ok(()));
317            }
318            Ok(Err(e)) => return std::task::Poll::Ready(Err(e)),
319            Err(_would_block) => {}
320        }
321    }
322}
323
324fn poll_write(
325    pty: &AsyncPty,
326    cx: &mut std::task::Context<'_>,
327    buf: &[u8],
328) -> std::task::Poll<std::io::Result<usize>> {
329    loop {
330        let mut guard = match pty.poll_write_ready(cx) {
331            std::task::Poll::Ready(guard) => guard,
332            std::task::Poll::Pending => return std::task::Poll::Pending,
333        }?;
334        match guard.try_io(|inner| inner.get_ref().write(buf)) {
335            Ok(result) => return std::task::Poll::Ready(result),
336            Err(_would_block) => {}
337        }
338    }
339}
340
341fn poll_flush(
342    pty: &AsyncPty,
343    cx: &mut std::task::Context<'_>,
344) -> std::task::Poll<std::io::Result<()>> {
345    loop {
346        let mut guard = match pty.poll_write_ready(cx) {
347            std::task::Poll::Ready(guard) => guard,
348            std::task::Poll::Pending => return std::task::Poll::Pending,
349        }?;
350        match guard.try_io(|inner| inner.get_ref().flush()) {
351            Ok(_) => return std::task::Poll::Ready(Ok(())),
352            Err(_would_block) => {}
353        }
354    }
355}