1#![allow(clippy::module_name_repetitions)]
2use std::io::Write as _;
3
/// The sys-level pty wrapped in tokio's `AsyncFd`, so that reads and writes
/// wait on reactor readiness (`poll_read_ready` / `poll_write_ready`) instead
/// of blocking the thread.
type AsyncPty = tokio::io::unix::AsyncFd<crate::sys::Pty>;
5
6pub fn open() -> crate::Result<(Pty, Pts)> {
12 let pty = crate::sys::Pty::open()?;
13 let pts = pty.pts()?;
14 pty.set_nonblocking()?;
15 let pty = tokio::io::unix::AsyncFd::new(pty)?;
16 Ok((Pty(pty), Pts(pts)))
17}
18
19pub struct Pty(AsyncPty);
21
22impl Pty {
23 pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> crate::Result<Self> {
32 Ok(Self(tokio::io::unix::AsyncFd::new(unsafe {
33 crate::sys::Pty::from_fd(fd)
34 })?))
35 }
36
37 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
42 self.0.get_ref().set_term_size(size)
43 }
44
45 pub fn split(&self) -> (ReadPty<'_>, WritePty<'_>) {
49 (ReadPty(&self.0), WritePty(&self.0))
50 }
51
52 #[must_use]
58 pub fn into_split(self) -> (OwnedReadPty, OwnedWritePty) {
59 let Self(pt) = self;
60 let read_pt = std::sync::Arc::new(pt);
61 let write_pt = std::sync::Arc::clone(&read_pt);
62 (OwnedReadPty(read_pt), OwnedWritePty(write_pt))
63 }
64}
65
66impl From<Pty> for std::os::fd::OwnedFd {
67 fn from(pty: Pty) -> Self {
68 pty.0.into_inner().into()
69 }
70}
71
72impl std::os::fd::AsFd for Pty {
73 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
74 self.0.get_ref().as_fd()
75 }
76}
77
78impl std::os::fd::AsRawFd for Pty {
79 fn as_raw_fd(&self) -> std::os::fd::RawFd {
80 self.0.get_ref().as_raw_fd()
81 }
82}
83
84impl tokio::io::AsyncRead for Pty {
85 fn poll_read(
86 self: std::pin::Pin<&mut Self>,
87 cx: &mut std::task::Context<'_>,
88 buf: &mut tokio::io::ReadBuf,
89 ) -> std::task::Poll<std::io::Result<()>> {
90 poll_read(&self.0, cx, buf)
91 }
92}
93
94impl tokio::io::AsyncWrite for Pty {
95 fn poll_write(
96 self: std::pin::Pin<&mut Self>,
97 cx: &mut std::task::Context<'_>,
98 buf: &[u8],
99 ) -> std::task::Poll<std::io::Result<usize>> {
100 poll_write(&self.0, cx, buf)
101 }
102
103 fn poll_flush(
104 self: std::pin::Pin<&mut Self>,
105 cx: &mut std::task::Context<'_>,
106 ) -> std::task::Poll<std::io::Result<()>> {
107 poll_flush(&self.0, cx)
108 }
109
110 fn poll_shutdown(
111 self: std::pin::Pin<&mut Self>,
112 _cx: &mut std::task::Context<'_>,
113 ) -> std::task::Poll<Result<(), std::io::Error>> {
114 std::task::Poll::Ready(Ok(()))
115 }
116}
117
118pub struct Pts(pub(crate) crate::sys::Pts);
122
123impl Pts {
124 #[must_use]
130 pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> Self {
131 Self(unsafe { crate::sys::Pts::from_fd(fd) })
132 }
133
134 pub fn setup_subprocess(
135 &self,
136 ) -> std::io::Result<(
137 std::process::Stdio,
138 std::process::Stdio,
139 std::process::Stdio,
140 )> {
141 self.0.setup_subprocess()
142 }
143
144 pub fn session_leader(&self) -> impl FnMut() -> std::io::Result<()> + use<> {
145 self.0.session_leader()
146 }
147}
148
149impl std::os::fd::AsFd for Pts {
150 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
151 self.0.as_fd()
152 }
153}
154
155impl std::os::fd::AsRawFd for Pts {
156 fn as_raw_fd(&self) -> std::os::fd::RawFd {
157 self.0.as_raw_fd()
158 }
159}
160
161pub struct ReadPty<'a>(&'a AsyncPty);
163
164impl tokio::io::AsyncRead for ReadPty<'_> {
165 fn poll_read(
166 self: std::pin::Pin<&mut Self>,
167 cx: &mut std::task::Context<'_>,
168 buf: &mut tokio::io::ReadBuf,
169 ) -> std::task::Poll<std::io::Result<()>> {
170 poll_read(self.0, cx, buf)
171 }
172}
173
174pub struct WritePty<'a>(&'a AsyncPty);
176
177impl WritePty<'_> {
178 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
183 self.0.get_ref().set_term_size(size)
184 }
185}
186
187impl tokio::io::AsyncWrite for WritePty<'_> {
188 fn poll_write(
189 self: std::pin::Pin<&mut Self>,
190 cx: &mut std::task::Context<'_>,
191 buf: &[u8],
192 ) -> std::task::Poll<std::io::Result<usize>> {
193 poll_write(self.0, cx, buf)
194 }
195
196 fn poll_flush(
197 self: std::pin::Pin<&mut Self>,
198 cx: &mut std::task::Context<'_>,
199 ) -> std::task::Poll<std::io::Result<()>> {
200 poll_flush(self.0, cx)
201 }
202
203 fn poll_shutdown(
204 self: std::pin::Pin<&mut Self>,
205 _cx: &mut std::task::Context<'_>,
206 ) -> std::task::Poll<Result<(), std::io::Error>> {
207 std::task::Poll::Ready(Ok(()))
208 }
209}
210
211#[derive(Debug)]
213pub struct OwnedReadPty(std::sync::Arc<AsyncPty>);
214
215impl OwnedReadPty {
216 pub fn unsplit(self, write_half: OwnedWritePty) -> crate::Result<Pty> {
224 let Self(read_pt) = self;
225 let OwnedWritePty(write_pt) = write_half;
226 if std::sync::Arc::ptr_eq(&read_pt, &write_pt) {
227 drop(write_pt);
228 Ok(Pty(std::sync::Arc::try_unwrap(read_pt)
229 .unwrap_or_else(|_| unreachable!())))
232 } else {
233 Err(crate::Error::Unsplit(
234 Self(read_pt),
235 OwnedWritePty(write_pt),
236 ))
237 }
238 }
239}
240
241impl tokio::io::AsyncRead for OwnedReadPty {
242 fn poll_read(
243 self: std::pin::Pin<&mut Self>,
244 cx: &mut std::task::Context<'_>,
245 buf: &mut tokio::io::ReadBuf,
246 ) -> std::task::Poll<std::io::Result<()>> {
247 poll_read(&self.0, cx, buf)
248 }
249}
250
251#[derive(Debug)]
253pub struct OwnedWritePty(std::sync::Arc<AsyncPty>);
254
255impl OwnedWritePty {
256 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
261 self.0.get_ref().set_term_size(size)
262 }
263}
264
265impl tokio::io::AsyncWrite for OwnedWritePty {
266 fn poll_write(
267 self: std::pin::Pin<&mut Self>,
268 cx: &mut std::task::Context<'_>,
269 buf: &[u8],
270 ) -> std::task::Poll<std::io::Result<usize>> {
271 poll_write(&self.0, cx, buf)
272 }
273
274 fn poll_flush(
275 self: std::pin::Pin<&mut Self>,
276 cx: &mut std::task::Context<'_>,
277 ) -> std::task::Poll<std::io::Result<()>> {
278 poll_flush(&self.0, cx)
279 }
280
281 fn poll_shutdown(
282 self: std::pin::Pin<&mut Self>,
283 _cx: &mut std::task::Context<'_>,
284 ) -> std::task::Poll<Result<(), std::io::Error>> {
285 std::task::Poll::Ready(Ok(()))
286 }
287}
288
289fn poll_read(
290 pty: &AsyncPty,
291 cx: &mut std::task::Context<'_>,
292 buf: &mut tokio::io::ReadBuf,
293) -> std::task::Poll<std::io::Result<()>> {
294 loop {
295 let mut guard = match pty.poll_read_ready(cx) {
296 std::task::Poll::Ready(guard) => guard,
297 std::task::Poll::Pending => return std::task::Poll::Pending,
298 }?;
299 let prev_filled = buf.filled().len();
300 let b = unsafe { buf.unfilled_mut() };
303 match guard.try_io(|inner| inner.get_ref().read_buf(b)) {
304 Ok(Ok((filled, _unfilled))) => {
305 let bytes = filled.len();
306 unsafe { buf.assume_init(prev_filled + bytes) };
315 buf.advance(bytes);
316 return std::task::Poll::Ready(Ok(()));
317 }
318 Ok(Err(e)) => return std::task::Poll::Ready(Err(e)),
319 Err(_would_block) => {}
320 }
321 }
322}
323
324fn poll_write(
325 pty: &AsyncPty,
326 cx: &mut std::task::Context<'_>,
327 buf: &[u8],
328) -> std::task::Poll<std::io::Result<usize>> {
329 loop {
330 let mut guard = match pty.poll_write_ready(cx) {
331 std::task::Poll::Ready(guard) => guard,
332 std::task::Poll::Pending => return std::task::Poll::Pending,
333 }?;
334 match guard.try_io(|inner| inner.get_ref().write(buf)) {
335 Ok(result) => return std::task::Poll::Ready(result),
336 Err(_would_block) => {}
337 }
338 }
339}
340
341fn poll_flush(
342 pty: &AsyncPty,
343 cx: &mut std::task::Context<'_>,
344) -> std::task::Poll<std::io::Result<()>> {
345 loop {
346 let mut guard = match pty.poll_write_ready(cx) {
347 std::task::Poll::Ready(guard) => guard,
348 std::task::Poll::Pending => return std::task::Poll::Pending,
349 }?;
350 match guard.try_io(|inner| inner.get_ref().flush()) {
351 Ok(_) => return std::task::Poll::Ready(Ok(())),
352 Err(_would_block) => {}
353 }
354 }
355}