use crate::cli::{IsTerminal, StdinStream, StdoutStream};
use crate::p2;
use bytes::Bytes;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio::sync::{Mutex, OwnedMutexGuard};
use wasmtime_wasi_io::streams::{InputStream, OutputStream};
/// Internal abstraction over stream types that can be polled for readiness.
///
/// This is the bound required by `StdioHandle<S>`, whose `poll` method calls
/// `poll_ready` on the locked stream before running an operation on it.
trait SharedHandleReady: Send + Sync + 'static {
/// Polls this stream for readiness; `Poll::Ready(())` means the caller may
/// go ahead with its operation, `Poll::Pending` means it must be polled again.
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}
/// Lets `StdioHandle` wait for a `p2::pipe::AsyncWriteStream` to become ready.
impl SharedHandleReady for p2::pipe::AsyncWriteStream {
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
// NOTE(review): `<Self>::poll_ready` is a type-relative path, and inherent
// associated functions take precedence over trait methods in such paths.
// This presumably forwards to an inherent `AsyncWriteStream::poll_ready`
// defined outside this file; if none exists the call would recurse into
// this very method — confirm against the stream's definition.
<Self>::poll_ready(self, cx)
}
}
/// Lets `StdioHandle` wait for a `p2::pipe::AsyncReadStream` to become ready.
impl SharedHandleReady for p2::pipe::AsyncReadStream {
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
// NOTE(review): `<Self>::poll_ready` is a type-relative path, and inherent
// associated functions take precedence over trait methods in such paths.
// This presumably forwards to an inherent `AsyncReadStream::poll_ready`
// defined outside this file; if none exists the call would recurse into
// this very method — confirm against the stream's definition.
<Self>::poll_ready(self, cx)
}
}
/// An implementation of [`StdinStream`] built on top of an [`AsyncRead`].
///
/// The wrapped [`p2::pipe::AsyncReadStream`] is shared through an `Arc`, so
/// every stream handed out by the [`StdinStream`] impl refers to the same
/// underlying reader.
//
// The mutex is a `tokio::sync::Mutex`: the `Pollable` impl awaits the lock,
// while the synchronous `InputStream` methods use `try_lock` and return a trap
// if the lock is currently held elsewhere.
pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);
impl AsyncStdinStream {
    /// Creates a stdin stream that reads from the asynchronous reader `s`.
    ///
    /// The reader is wrapped in a [`p2::pipe::AsyncReadStream`] and placed
    /// behind a shared, lockable handle.
    pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {
        let reader = p2::pipe::AsyncReadStream::new(s);
        let shared = Arc::new(Mutex::new(reader));
        Self(shared)
    }
}
impl StdinStream for AsyncStdinStream {
fn p2_stream(&self) -> Box<dyn InputStream> {
Box::new(Self(self.0.clone()))
}
fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
Box::new(StdioHandle::Ready(self.0.clone()))
}
}
impl IsTerminal for AsyncStdinStream {
/// Always reports `false`: this stream is never presented as a terminal.
fn is_terminal(&self) -> bool {
false
}
}
#[async_trait::async_trait]
impl InputStream for AsyncStdinStream {
    /// Reads up to `size` bytes from the shared stream.
    ///
    /// Returns a trap if the shared lock is currently held elsewhere.
    fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {
        let Ok(mut stream) = self.0.try_lock() else {
            return Err(p2::StreamError::trap("concurrent reads are not supported"));
        };
        stream.read(size)
    }

    /// Skips up to `size` bytes of the shared stream.
    ///
    /// Returns a trap if the shared lock is currently held elsewhere.
    fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {
        let Ok(mut stream) = self.0.try_lock() else {
            return Err(p2::StreamError::trap("concurrent skips are not supported"));
        };
        stream.skip(size)
    }

    /// Cancels the inner stream, but only when this is the sole remaining
    /// reference to it and its lock is free; otherwise does nothing.
    async fn cancel(&mut self) {
        // `Arc::get_mut` only succeeds when no other handle shares the stream.
        let Some(mutex) = Arc::get_mut(&mut self.0) else {
            return;
        };
        if let Ok(mut stream) = mutex.try_lock() {
            stream.cancel().await
        }
    }
}
#[async_trait::async_trait]
impl p2::Pollable for AsyncStdinStream {
    /// Waits for the shared lock, then waits for the inner stream to be ready.
    ///
    /// The lock stays held for the whole duration of the inner `ready` call.
    async fn ready(&mut self) {
        let mut stream = self.0.lock().await;
        stream.ready().await
    }
}
impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {
    /// Reads at most `buf.remaining()` bytes from the shared stream into `buf`.
    ///
    /// When the handle reports no result (`None`), this completes successfully
    /// without adding any data to `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let outcome = ready!(
            self.as_mut()
                .poll(cx, |stream| stream.read(buf.remaining()))
        );
        let result = match outcome {
            Some(Ok(bytes)) => {
                buf.put_slice(&bytes);
                Ok(())
            }
            Some(Err(err)) => Err(err),
            // No guard could be obtained or the stream reported closed: finish
            // the read without filling in any data.
            None => Ok(()),
        };
        Poll::Ready(result)
    }
}
/// An implementation of [`StdoutStream`] built on top of an [`AsyncWrite`].
///
/// The wrapped [`p2::pipe::AsyncWriteStream`] is shared through an `Arc`, so
/// every stream handed out by the [`StdoutStream`] impl refers to the same
/// underlying writer.
//
// The mutex is a `tokio::sync::Mutex`: the `Pollable` impl awaits the lock,
// while the synchronous `OutputStream` methods use `try_lock` and return a trap
// if the lock is currently held elsewhere.
pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);
impl AsyncStdoutStream {
    /// Creates a stdout stream that writes to the asynchronous writer `s`.
    ///
    /// `budget` is forwarded unchanged to [`p2::pipe::AsyncWriteStream::new`]
    /// (presumably the stream's write budget — confirm against that
    /// constructor).
    pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {
        let writer = p2::pipe::AsyncWriteStream::new(budget, s);
        let shared = Arc::new(Mutex::new(writer));
        Self(shared)
    }
}
impl StdoutStream for AsyncStdoutStream {
fn p2_stream(&self) -> Box<dyn OutputStream> {
Box::new(Self(self.0.clone()))
}
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(StdioHandle::Ready(self.0.clone()))
}
}
impl IsTerminal for AsyncStdoutStream {
/// Always reports `false`: this stream is never presented as a terminal.
fn is_terminal(&self) -> bool {
false
}
}
#[async_trait::async_trait]
impl OutputStream for AsyncStdoutStream {
    /// Forwards `check_write` to the shared stream.
    ///
    /// Returns a trap if the shared lock is currently held elsewhere.
    fn check_write(&mut self) -> Result<usize, p2::StreamError> {
        let Ok(mut stream) = self.0.try_lock() else {
            return Err(p2::StreamError::trap("concurrent writes are not supported"));
        };
        stream.check_write()
    }

    /// Forwards `bytes` to the shared stream's `write`.
    ///
    /// Returns a trap if the shared lock is currently held elsewhere.
    fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {
        let Ok(mut stream) = self.0.try_lock() else {
            return Err(p2::StreamError::trap("concurrent writes not supported yet"));
        };
        stream.write(bytes)
    }

    /// Forwards `flush` to the shared stream.
    ///
    /// Returns a trap if the shared lock is currently held elsewhere.
    fn flush(&mut self) -> Result<(), p2::StreamError> {
        let Ok(mut stream) = self.0.try_lock() else {
            return Err(p2::StreamError::trap(
                "concurrent flushes not supported yet",
            ));
        };
        stream.flush()
    }

    /// Cancels the inner stream, but only when this is the sole remaining
    /// reference to it and its lock is free; otherwise does nothing.
    async fn cancel(&mut self) {
        // `Arc::get_mut` only succeeds when no other handle shares the stream.
        let Some(mutex) = Arc::get_mut(&mut self.0) else {
            return;
        };
        if let Ok(mut stream) = mutex.try_lock() {
            stream.cancel().await
        }
    }
}
#[async_trait::async_trait]
impl p2::Pollable for AsyncStdoutStream {
    /// Waits for the shared lock, then waits for the inner stream to be ready.
    ///
    /// The lock stays held for the whole duration of the inner `ready` call.
    async fn ready(&mut self) {
        let mut stream = self.0.lock().await;
        stream.ready().await
    }
}
/// [`AsyncWrite`] adapter over a shared [`p2::pipe::AsyncWriteStream`] handle.
///
/// Every operation is routed through `StdioHandle::poll`, which acquires the
/// shared lock and waits for the stream's `poll_ready` before running it.
impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {
    /// Copies `buf` into a [`Bytes`] and hands the whole buffer to the
    /// stream's `write` in one call.
    ///
    /// Resolves to `Ok(buf.len())` on success, to the I/O error reported by
    /// the stream on failure, and to `Ok(0)` when the handle yields no result
    /// (it is closed).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {
            Some(Ok(())) => Poll::Ready(Ok(buf.len())),
            Some(Err(e)) => Poll::Ready(Err(e)),
            // Closed handle: report that zero bytes were accepted.
            None => Poll::Ready(Ok(0)),
        }
    }

    /// Invokes the stream's `flush` once the handle is locked and ready.
    ///
    /// A closed handle (no result from `poll`) is treated as a successful
    /// flush.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match ready!(self.poll(cx, |i| i.flush())) {
            Some(result) => Poll::Ready(result),
            None => Poll::Ready(Ok(())),
        }
    }

    /// Shuts the writer down by flushing it.
    ///
    /// The `AsyncWrite` contract states that a shutdown implies a flush, so
    /// this must not report success without going through `poll_flush`;
    /// otherwise callers that only call `shutdown()` could lose pending data
    /// and never observe a flush error.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
/// State machine that adapts a shared `Arc<Mutex<S>>` stream to poll-based
/// I/O traits such as `AsyncRead` and `AsyncWrite`.
///
/// `StdioHandle::poll` drives the transitions: `Ready` -> `Locking` ->
/// `Locked`, and back to `Ready` once an operation has succeeded.
enum StdioHandle<S> {
/// The lock is not held; stores the shared mutex to lock for the next
/// operation.
Ready(Arc<Mutex<S>>),
/// A lock acquisition (`lock_owned`) is in flight.
Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),
/// The lock is held, either waiting for the stream to become ready or about
/// to run the operation.
Locked(OwnedMutexGuard<S>),
/// Placeholder left behind when the guard is taken out of `Locked`; it is
/// also where the handle stays after a closed stream or a failed operation.
/// Polling in this state yields `None`.
Closed,
}
impl<S> StdioHandle<S>
where
S: SharedHandleReady,
{
/// Drives the handle until `op` can be run against the locked, ready stream.
///
/// Returns:
/// * `Poll::Pending` while the lock or the stream's readiness is awaited;
/// * `Poll::Ready(Some(Ok(_)))` with `op`'s result, leaving the handle in
///   `Ready` for the next call;
/// * `Poll::Ready(Some(Err(_)))` when `op` reports `LastOperationFailed`;
/// * `Poll::Ready(None)` when the handle is `Closed` or `op` reports
///   `StreamError::Closed`.
///
/// # Panics
///
/// Panics if `op` returns `StreamError::Trap`, or if the error carried by
/// `LastOperationFailed` does not downcast to the `io::Result` error type.
fn poll<T>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
op: impl FnOnce(&mut S) -> p2::StreamResult<T>,
) -> Poll<Option<io::Result<T>>> {
// Not holding the lock yet: start acquiring it.
if let StdioHandle::Ready(lock) = &*self {
self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));
}
// A lock acquisition is in flight: wait for it and store the guard.
if let Some(lock) = self.as_mut().as_locking() {
let guard = ready!(lock.poll(cx));
self.set(StdioHandle::Locked(guard));
}
// Take the guard out, which leaves `Closed` behind as a placeholder. If
// there is no guard at this point the handle was already `Closed`.
let mut guard = match self.as_mut().take_guard() {
Some(guard) => guard,
None => return Poll::Ready(None),
};
// Wait for the locked stream to be ready; if it is not, put the guard back
// so the next poll resumes from the `Locked` state.
match guard.poll_ready(cx) {
Poll::Ready(()) => {}
Poll::Pending => {
self.set(StdioHandle::Locked(guard));
return Poll::Pending;
}
}
// Run the operation. Only on success is the handle returned to `Ready`; on
// the error paths it stays `Closed` and the guard is dropped on return.
match op(&mut guard) {
Ok(result) => {
self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));
Poll::Ready(Some(Ok(result)))
}
Err(p2::StreamError::Closed) => Poll::Ready(None),
Err(p2::StreamError::LastOperationFailed(e)) => {
// NOTE(review): assumes every failure reported by the stream wraps a
// `std::io::Error`; any other payload panics here — confirm.
Poll::Ready(Some(Err(e.downcast().unwrap())))
}
// NOTE(review): assumes `op` can never trap once `poll_ready` has
// succeeded — confirm against the stream implementations.
Err(p2::StreamError::Trap(_)) => unreachable!(),
}
}
/// Returns a pinned reference to the in-flight lock future when the handle
/// is in the `Locking` state, and `None` in every other state.
fn as_locking(
self: Pin<&mut Self>,
) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {
// SAFETY (review note): the returned pin points at the future inside its
// `Box`, i.e. at a heap allocation that does not move when the enum does.
// The code in this impl never moves the future out of the box; it is only
// dropped when `Pin::set` overwrites the `Locking` variant.
unsafe {
match self.get_unchecked_mut() {
StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),
_ => None,
}
}
}
/// Moves the guard out of the `Locked` state, leaving `Closed` in its place.
///
/// Returns `None`, without modifying the handle, in any other state.
fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {
if !matches!(*self, StdioHandle::Locked(_)) {
return None;
}
// SAFETY (review note): this point is only reached in the `Locked` state
// (checked above), which holds just an `OwnedMutexGuard`; nothing in this
// impl creates a pinned reference to that guard, so moving it out via
// `mem::replace` does not invalidate any pin handed out by `as_locking`.
unsafe {
match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {
StdioHandle::Locked(guard) => Some(guard),
_ => unreachable!(),
}
}
}
}