1#![allow(clippy::module_name_repetitions)]
2
3use std::io::Write as _;
4
/// The platform pty handle wrapped in tokio's `AsyncFd`, which lets the
/// tokio reactor report read/write readiness for the underlying descriptor.
type AsyncPty = tokio::io::unix::AsyncFd<crate::sys::Pty>;
6
7pub fn open() -> crate::Result<(Pty, Pts)> {
13 let pty = crate::sys::Pty::open()?;
14 let pts = pty.pts()?;
15 pty.set_nonblocking()?;
16 let pty = tokio::io::unix::AsyncFd::new(pty)?;
17 Ok((Pty(pty), Pts(pts)))
18}
19
20pub struct Pty(AsyncPty);
22
23impl Pty {
24 pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> crate::Result<Self> {
33 Ok(Self(tokio::io::unix::AsyncFd::new(unsafe {
34 crate::sys::Pty::from_fd(fd)
35 })?))
36 }
37
38 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
43 self.0.get_ref().set_term_size(size)
44 }
45
46 pub fn split(&mut self) -> (ReadPty<'_>, WritePty<'_>) {
50 (ReadPty(&self.0), WritePty(&self.0))
51 }
52
53 #[must_use]
59 pub fn into_split(self) -> (OwnedReadPty, OwnedWritePty) {
60 let Self(pt) = self;
61 let read_pt = std::sync::Arc::new(pt);
62 let write_pt = std::sync::Arc::clone(&read_pt);
63 (OwnedReadPty(read_pt), OwnedWritePty(write_pt))
64 }
65}
66
67impl From<Pty> for std::os::fd::OwnedFd {
68 fn from(pty: Pty) -> Self {
69 pty.0.into_inner().into()
70 }
71}
72
73impl std::os::fd::AsFd for Pty {
74 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
75 self.0.get_ref().as_fd()
76 }
77}
78
79impl std::os::fd::AsRawFd for Pty {
80 fn as_raw_fd(&self) -> std::os::fd::RawFd {
81 self.0.get_ref().as_raw_fd()
82 }
83}
84
85impl tokio::io::AsyncRead for Pty {
86 fn poll_read(
87 self: std::pin::Pin<&mut Self>,
88 cx: &mut std::task::Context<'_>,
89 buf: &mut tokio::io::ReadBuf,
90 ) -> std::task::Poll<std::io::Result<()>> {
91 poll_read(&self.0, cx, buf)
92 }
93}
94
95impl tokio::io::AsyncWrite for Pty {
96 fn poll_write(
97 self: std::pin::Pin<&mut Self>,
98 cx: &mut std::task::Context<'_>,
99 buf: &[u8],
100 ) -> std::task::Poll<std::io::Result<usize>> {
101 poll_write(&self.0, cx, buf)
102 }
103
104 fn poll_flush(
105 self: std::pin::Pin<&mut Self>,
106 cx: &mut std::task::Context<'_>,
107 ) -> std::task::Poll<std::io::Result<()>> {
108 poll_flush(&self.0, cx)
109 }
110
111 fn poll_shutdown(
112 self: std::pin::Pin<&mut Self>,
113 _cx: &mut std::task::Context<'_>,
114 ) -> std::task::Poll<Result<(), std::io::Error>> {
115 std::task::Poll::Ready(Ok(()))
116 }
117}
118
/// The pts end of a pseudoterminal pair, as returned by [`open`].
///
/// Unlike [`Pty`], it is not registered with the tokio reactor.
pub struct Pts(pub(crate) crate::sys::Pts);
123
124impl Pts {
125 #[must_use]
131 pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> Self {
132 Self(unsafe { crate::sys::Pts::from_fd(fd) })
133 }
134}
135
136impl std::os::fd::AsFd for Pts {
137 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
138 self.0.as_fd()
139 }
140}
141
142impl std::os::fd::AsRawFd for Pts {
143 fn as_raw_fd(&self) -> std::os::fd::RawFd {
144 self.0.as_raw_fd()
145 }
146}
147
148pub struct ReadPty<'a>(&'a AsyncPty);
150
151impl tokio::io::AsyncRead for ReadPty<'_> {
152 fn poll_read(
153 self: std::pin::Pin<&mut Self>,
154 cx: &mut std::task::Context<'_>,
155 buf: &mut tokio::io::ReadBuf,
156 ) -> std::task::Poll<std::io::Result<()>> {
157 poll_read(self.0, cx, buf)
158 }
159}
160
161pub struct WritePty<'a>(&'a AsyncPty);
163
164impl WritePty<'_> {
165 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
170 self.0.get_ref().set_term_size(size)
171 }
172}
173
174impl tokio::io::AsyncWrite for WritePty<'_> {
175 fn poll_write(
176 self: std::pin::Pin<&mut Self>,
177 cx: &mut std::task::Context<'_>,
178 buf: &[u8],
179 ) -> std::task::Poll<std::io::Result<usize>> {
180 poll_write(self.0, cx, buf)
181 }
182
183 fn poll_flush(
184 self: std::pin::Pin<&mut Self>,
185 cx: &mut std::task::Context<'_>,
186 ) -> std::task::Poll<std::io::Result<()>> {
187 poll_flush(self.0, cx)
188 }
189
190 fn poll_shutdown(
191 self: std::pin::Pin<&mut Self>,
192 _cx: &mut std::task::Context<'_>,
193 ) -> std::task::Poll<Result<(), std::io::Error>> {
194 std::task::Poll::Ready(Ok(()))
195 }
196}
197
/// Owned read half of a [`Pty`], created by [`Pty::into_split`].
///
/// Shares the registered descriptor with its [`OwnedWritePty`] partner
/// through an `Arc`; rejoin the two with [`OwnedReadPty::unsplit`].
#[derive(Debug)]
pub struct OwnedReadPty(std::sync::Arc<AsyncPty>);
201
202impl OwnedReadPty {
203 pub fn unsplit(self, write_half: OwnedWritePty) -> crate::Result<Pty> {
211 let Self(read_pt) = self;
212 let OwnedWritePty(write_pt) = write_half;
213 if std::sync::Arc::ptr_eq(&read_pt, &write_pt) {
214 drop(write_pt);
215 Ok(Pty(std::sync::Arc::try_unwrap(read_pt)
216 .unwrap_or_else(|_| unreachable!())))
219 } else {
220 Err(crate::Error::Unsplit(
221 Self(read_pt),
222 OwnedWritePty(write_pt),
223 ))
224 }
225 }
226}
227
228impl tokio::io::AsyncRead for OwnedReadPty {
229 fn poll_read(
230 self: std::pin::Pin<&mut Self>,
231 cx: &mut std::task::Context<'_>,
232 buf: &mut tokio::io::ReadBuf,
233 ) -> std::task::Poll<std::io::Result<()>> {
234 poll_read(&self.0, cx, buf)
235 }
236}
237
/// Owned write half of a [`Pty`], created by [`Pty::into_split`].
///
/// Shares the registered descriptor with its [`OwnedReadPty`] partner
/// through an `Arc`; pass it to [`OwnedReadPty::unsplit`] to rejoin them.
#[derive(Debug)]
pub struct OwnedWritePty(std::sync::Arc<AsyncPty>);
241
242impl OwnedWritePty {
243 pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
248 self.0.get_ref().set_term_size(size)
249 }
250}
251
252impl tokio::io::AsyncWrite for OwnedWritePty {
253 fn poll_write(
254 self: std::pin::Pin<&mut Self>,
255 cx: &mut std::task::Context<'_>,
256 buf: &[u8],
257 ) -> std::task::Poll<std::io::Result<usize>> {
258 poll_write(&self.0, cx, buf)
259 }
260
261 fn poll_flush(
262 self: std::pin::Pin<&mut Self>,
263 cx: &mut std::task::Context<'_>,
264 ) -> std::task::Poll<std::io::Result<()>> {
265 poll_flush(&self.0, cx)
266 }
267
268 fn poll_shutdown(
269 self: std::pin::Pin<&mut Self>,
270 _cx: &mut std::task::Context<'_>,
271 ) -> std::task::Poll<Result<(), std::io::Error>> {
272 std::task::Poll::Ready(Ok(()))
273 }
274}
275
276fn poll_read(
277 pty: &AsyncPty,
278 cx: &mut std::task::Context<'_>,
279 buf: &mut tokio::io::ReadBuf,
280) -> std::task::Poll<std::io::Result<()>> {
281 loop {
282 let mut guard = match pty.poll_read_ready(cx) {
283 std::task::Poll::Ready(guard) => guard,
284 std::task::Poll::Pending => return std::task::Poll::Pending,
285 }?;
286 let prev_filled = buf.filled().len();
287 let b = unsafe { buf.unfilled_mut() };
290 match guard.try_io(|inner| inner.get_ref().read_buf(b)) {
291 Ok(Ok((filled, _unfilled))) => {
292 let bytes = filled.len();
293 unsafe { buf.assume_init(prev_filled + bytes) };
302 buf.advance(bytes);
303 return std::task::Poll::Ready(Ok(()));
304 }
305 Ok(Err(e)) => return std::task::Poll::Ready(Err(e)),
306 Err(_would_block) => {}
307 }
308 }
309}
310
311fn poll_write(
312 pty: &AsyncPty,
313 cx: &mut std::task::Context<'_>,
314 buf: &[u8],
315) -> std::task::Poll<std::io::Result<usize>> {
316 loop {
317 let mut guard = match pty.poll_write_ready(cx) {
318 std::task::Poll::Ready(guard) => guard,
319 std::task::Poll::Pending => return std::task::Poll::Pending,
320 }?;
321 match guard.try_io(|inner| inner.get_ref().write(buf)) {
322 Ok(result) => return std::task::Poll::Ready(result),
323 Err(_would_block) => {}
324 }
325 }
326}
327
328fn poll_flush(
329 pty: &AsyncPty,
330 cx: &mut std::task::Context<'_>,
331) -> std::task::Poll<std::io::Result<()>> {
332 loop {
333 let mut guard = match pty.poll_write_ready(cx) {
334 std::task::Poll::Ready(guard) => guard,
335 std::task::Poll::Pending => return std::task::Poll::Pending,
336 }?;
337 match guard.try_io(|inner| inner.get_ref().flush()) {
338 Ok(_) => return std::task::Poll::Ready(Ok(())),
339 Err(_would_block) => {}
340 }
341 }
342}