#![allow(clippy::module_name_repetitions)]
use std::io::Write as _;
/// The platform pty handle (`crate::sys::Pty`) wrapped in tokio's `AsyncFd`,
/// which is what every read/write half in this module polls for readiness.
type AsyncPty = tokio::io::unix::AsyncFd<crate::sys::Pty>;
/// Allocates a new pty and returns it together with its pts.
///
/// The pts is obtained from the freshly opened pty first; the pty is then
/// switched to nonblocking mode and registered with tokio through `AsyncFd`.
///
/// # Errors
///
/// Returns an error if opening the pty, obtaining its pts, switching it to
/// nonblocking mode, or registering it with `AsyncFd` fails.
///
/// # Panics
///
/// `AsyncFd::new` is documented by tokio to panic when called outside the
/// context of a tokio runtime with IO enabled.
pub fn open() -> crate::Result<(Pty, Pts)> {
    let raw = crate::sys::Pty::open()?;
    let secondary = raw.pts()?;
    raw.set_nonblocking()?;
    let registered = tokio::io::unix::AsyncFd::new(raw)?;
    Ok((Pty(registered), Pts(secondary)))
}
/// An allocated pty, registered with tokio so it can be read from and
/// written to asynchronously.
pub struct Pty(AsyncPty);

impl Pty {
    /// Wraps an existing file descriptor in a `Pty`.
    ///
    /// Unlike [`open`], this constructor does not call `set_nonblocking`, so
    /// presumably the descriptor must already be in nonblocking mode
    /// (NOTE(review): confirm against callers).
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `crate::sys::Pty::from_fd`,
    /// to which `fd` is passed unchanged (presumably: `fd` refers to a valid,
    /// open pty — TODO confirm in `crate::sys`).
    ///
    /// # Errors
    ///
    /// Returns an error if registering the descriptor with `AsyncFd` fails.
    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> crate::Result<Self> {
        // SAFETY: the requirements of `crate::sys::Pty::from_fd` are
        // forwarded to our own caller by this function's safety contract.
        let raw = unsafe { crate::sys::Pty::from_fd(fd) };
        let registered = tokio::io::unix::AsyncFd::new(raw)?;
        Ok(Self(registered))
    }

    /// Changes the terminal size associated with this pty.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `set_term_size` call reports.
    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
        let raw = self.0.get_ref();
        raw.set_term_size(size)
    }

    /// Borrows the pty as a separate read half and write half.
    ///
    /// Both halves refer to the same underlying pty and cannot outlive
    /// `self`.
    pub fn split(&self) -> (ReadPty<'_>, WritePty<'_>) {
        let shared = &self.0;
        (ReadPty(shared), WritePty(shared))
    }

    /// Consumes the pty and returns owned read and write halves.
    ///
    /// The halves share the pty through an `Arc`; they can be recombined
    /// with [`OwnedReadPty::unsplit`].
    #[must_use]
    pub fn into_split(self) -> (OwnedReadPty, OwnedWritePty) {
        let shared = std::sync::Arc::new(self.0);
        let reader = OwnedReadPty(std::sync::Arc::clone(&shared));
        (reader, OwnedWritePty(shared))
    }
}
impl From<Pty> for std::os::fd::OwnedFd {
fn from(pty: Pty) -> Self {
pty.0.into_inner().into()
}
}
impl std::os::fd::AsFd for Pty {
    /// Borrows the file descriptor of the underlying platform pty.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
        let raw = self.0.get_ref();
        raw.as_fd()
    }
}
impl std::os::fd::AsRawFd for Pty {
    /// Returns the raw file descriptor of the underlying platform pty.
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        let raw = self.0.get_ref();
        raw.as_raw_fd()
    }
}
impl tokio::io::AsyncRead for Pty {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf,
) -> std::task::Poll<std::io::Result<()>> {
poll_read(&self.0, cx, buf)
}
}
impl tokio::io::AsyncWrite for Pty {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
poll_write(&self.0, cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
poll_flush(&self.0, cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
/// The pts side of a pty, wrapping the platform `crate::sys::Pts` handle.
pub struct Pts(pub(crate) crate::sys::Pts);

impl Pts {
    /// Wraps an existing file descriptor in a `Pts`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `crate::sys::Pts::from_fd`,
    /// to which `fd` is passed unchanged (presumably: `fd` refers to a valid,
    /// open pts — TODO confirm in `crate::sys`).
    #[must_use]
    pub unsafe fn from_fd(fd: std::os::fd::OwnedFd) -> Self {
        // SAFETY: the requirements of `crate::sys::Pts::from_fd` are
        // forwarded to our own caller by this function's safety contract.
        let raw = unsafe { crate::sys::Pts::from_fd(fd) };
        Self(raw)
    }

    /// Produces three `Stdio` handles by delegating to the platform layer.
    ///
    /// NOTE(review): presumably these are the child's stdin, stdout and
    /// stderr, in that order — confirm against
    /// `crate::sys::Pts::setup_subprocess`.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the platform layer reports.
    pub fn setup_subprocess(
        &self,
    ) -> std::io::Result<(
        std::process::Stdio,
        std::process::Stdio,
        std::process::Stdio,
    )> {
        let Self(raw) = self;
        raw.setup_subprocess()
    }

    /// Returns a closure built by the platform layer; the closure does not
    /// borrow from `self` (`use<>`).
    ///
    /// NOTE(review): presumably meant to run in the child before exec to
    /// make it a session leader — confirm against
    /// `crate::sys::Pts::session_leader`.
    pub fn session_leader(&self) -> impl FnMut() -> std::io::Result<()> + use<> {
        let Self(raw) = self;
        raw.session_leader()
    }
}
impl std::os::fd::AsFd for Pts {
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
self.0.as_fd()
}
}
impl std::os::fd::AsRawFd for Pts {
fn as_raw_fd(&self) -> std::os::fd::RawFd {
self.0.as_raw_fd()
}
}
/// Borrowed read half of a [`Pty`], created by [`Pty::split`].
pub struct ReadPty<'a>(&'a AsyncPty);

impl tokio::io::AsyncRead for ReadPty<'_> {
    /// Reads from the borrowed pty by delegating to the module-level
    /// `poll_read`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> std::task::Poll<std::io::Result<()>> {
        let source: &AsyncPty = self.0;
        poll_read(source, cx, buf)
    }
}
/// Borrowed write half of a [`Pty`], created by [`Pty::split`].
pub struct WritePty<'a>(&'a AsyncPty);

impl WritePty<'_> {
    /// Changes the terminal size associated with the underlying pty.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `set_term_size` call reports.
    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
        let raw = self.0.get_ref();
        raw.set_term_size(size)
    }
}
impl tokio::io::AsyncWrite for WritePty<'_> {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
poll_write(self.0, cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
poll_flush(self.0, cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
pub struct OwnedReadPty(std::sync::Arc<AsyncPty>);
impl OwnedReadPty {
pub fn unsplit(self, write_half: OwnedWritePty) -> crate::Result<Pty> {
let Self(read_pt) = self;
let OwnedWritePty(write_pt) = write_half;
if std::sync::Arc::ptr_eq(&read_pt, &write_pt) {
drop(write_pt);
Ok(Pty(std::sync::Arc::try_unwrap(read_pt)
.unwrap_or_else(|_| unreachable!())))
} else {
Err(crate::Error::Unsplit(
Self(read_pt),
OwnedWritePty(write_pt),
))
}
}
}
impl tokio::io::AsyncRead for OwnedReadPty {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf,
) -> std::task::Poll<std::io::Result<()>> {
poll_read(&self.0, cx, buf)
}
}
/// Owned write half of a [`Pty`], created by [`Pty::into_split`].
#[derive(Debug)]
pub struct OwnedWritePty(std::sync::Arc<AsyncPty>);

impl OwnedWritePty {
    /// Changes the terminal size associated with the underlying pty.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `set_term_size` call reports.
    pub fn resize(&self, size: crate::Size) -> crate::Result<()> {
        let raw = self.0.get_ref();
        raw.set_term_size(size)
    }
}
impl tokio::io::AsyncWrite for OwnedWritePty {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
poll_write(&self.0, cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
poll_flush(&self.0, cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
/// Shared `AsyncRead::poll_read` implementation for every read half.
///
/// Waits for the pty to become readable, then attempts a single read into
/// the unfilled part of `buf`. On success the bytes that were read are
/// marked as initialized and filled. If the read reports `WouldBlock`, the
/// readiness is re-polled, which either registers the waker and returns
/// `Pending` or retries the read.
///
/// # Errors
///
/// Propagates errors from polling for readiness and from the read itself.
fn poll_read(
    pty: &AsyncPty,
    cx: &mut std::task::Context<'_>,
    buf: &mut tokio::io::ReadBuf,
) -> std::task::Poll<std::io::Result<()>> {
    loop {
        let mut guard = match pty.poll_read_ready(cx) {
            std::task::Poll::Ready(guard) => guard,
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }?;
        // SAFETY: the slice is only handed to `read_buf`, which is assumed
        // to never de-initialize any byte it is given (TODO confirm against
        // `crate::sys::Pty::read_buf`).
        let b = unsafe { buf.unfilled_mut() };
        match guard.try_io(|inner| inner.get_ref().read_buf(b)) {
            Ok(Ok((filled, _unfilled))) => {
                let bytes = filled.len();
                // SAFETY: `filled` is the initialized prefix of the unfilled
                // region that `read_buf` wrote into, so exactly `bytes`
                // bytes past the current fill mark are initialized.
                // `ReadBuf::assume_init` counts from the end of the filled
                // region, so the previously filled length must NOT be added
                // here: doing so would claim more initialized bytes than
                // actually exist whenever `buf` already contained data.
                unsafe { buf.assume_init(bytes) };
                buf.advance(bytes);
                return std::task::Poll::Ready(Ok(()));
            }
            Ok(Err(e)) => return std::task::Poll::Ready(Err(e)),
            // The read would have blocked; loop to poll readiness again.
            Err(_would_block) => {}
        }
    }
}
/// Shared `AsyncWrite::poll_write` implementation for every write half.
///
/// Waits for the pty to become writable, then attempts a single write of
/// `buf`. If the write reports `WouldBlock`, the readiness is re-polled,
/// which either registers the waker and returns `Pending` or retries.
///
/// # Errors
///
/// Propagates errors from polling for readiness and from the write itself.
fn poll_write(
    pty: &AsyncPty,
    cx: &mut std::task::Context<'_>,
    buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
    loop {
        let mut ready_guard = std::task::ready!(pty.poll_write_ready(cx))?;
        // `try_io` yields `Err` only when the write would have blocked; in
        // that case fall through and poll readiness again.
        if let Ok(outcome) = ready_guard.try_io(|afd| afd.get_ref().write(buf)) {
            return std::task::Poll::Ready(outcome);
        }
    }
}
/// Shared `AsyncWrite::poll_flush` implementation for every write half.
///
/// Waits for the pty to become writable, then attempts to flush it. If the
/// flush reports `WouldBlock`, the readiness is re-polled, which either
/// registers the waker and returns `Pending` or retries.
///
/// # Errors
///
/// Propagates errors from polling for readiness and from the flush itself.
fn poll_flush(
    pty: &AsyncPty,
    cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
    loop {
        let mut guard = match pty.poll_write_ready(cx) {
            std::task::Poll::Ready(guard) => guard,
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }?;
        match guard.try_io(|inner| inner.get_ref().flush()) {
            // Hand the flush outcome to the caller as-is, so that a failed
            // flush is reported instead of being turned into a success.
            Ok(result) => return std::task::Poll::Ready(result),
            // The flush would have blocked; loop to poll readiness again.
            Err(_would_block) => {}
        }
    }
}