use crate::cli::{IsTerminal, StdinStream};
use bytes::{Bytes, BytesMut};
use std::io::Read;
use std::mem;
use std::pin::Pin;
use std::sync::{Condvar, Mutex, OnceLock};
use std::task::{Context, Poll};
use tokio::io::{self, AsyncRead, ReadBuf};
use tokio::sync::Notify;
use tokio::sync::futures::Notified;
use wasmtime_wasi_io::{
poll::Pollable,
streams::{InputStream, StreamError},
};
impl IsTerminal for tokio::io::Stdin {
    /// Reports whether the process's standard input is attached to a terminal.
    ///
    /// The `tokio::io::Stdin` receiver is not inspected; the answer comes from
    /// a fresh `std::io::stdin()` handle via this crate's `IsTerminal` trait.
    fn is_terminal(&self) -> bool {
        <std::io::Stdin as IsTerminal>::is_terminal(&std::io::stdin())
    }
}
impl StdinStream for tokio::io::Stdin {
    /// Returns a boxed `InputStream` handle backed by the shared stdin state.
    fn p2_stream(&self) -> Box<dyn InputStream> {
        let stream = WasiStdin;
        Box::new(stream)
    }

    /// Returns a boxed `AsyncRead` handle that starts out in the `Ready`
    /// (not waiting on any notification) state.
    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
        let reader = WasiStdinAsyncRead::Ready;
        Box::new(reader)
    }
}
impl IsTerminal for std::io::Stdin {
    /// Reports whether this stdin handle refers to a terminal by delegating to
    /// the standard library's `std::io::IsTerminal` implementation.
    fn is_terminal(&self) -> bool {
        <Self as std::io::IsTerminal>::is_terminal(self)
    }
}
impl StdinStream for std::io::Stdin {
    /// Returns a boxed `InputStream` handle backed by the shared stdin state.
    fn p2_stream(&self) -> Box<dyn InputStream> {
        let stream = WasiStdin;
        Box::new(stream)
    }

    /// Returns a boxed `AsyncRead` handle that starts out in the `Ready`
    /// (not waiting on any notification) state.
    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
        let reader = WasiStdinAsyncRead::Ready;
        Box::new(reader)
    }
}
/// Shared stdin coordination state used by the background reader thread and by
/// every `WasiStdin` / `WasiStdinAsyncRead` handle.
///
/// A single instance lives in a `static` and is reached through
/// `GlobalStdin::get`.
#[derive(Default)]
struct GlobalStdin {
/// Current phase of the read handshake; see `StdinState`.
state: Mutex<StdinState>,
/// Notified by consumers after they move `state` to `ReadRequested`; the
/// reader thread waits on this condition variable.
read_requested: Condvar,
/// Notified by the reader thread (via `notify_waiters`) after it has stored
/// the outcome of a read into `state`.
read_completed: Notify,
}
/// Phases of the handshake between stdin consumers and the reader thread.
#[derive(Default, Debug)]
enum StdinState {
/// Initial state: nobody has asked for data, so the reader thread stays
/// parked on the `read_requested` condition variable.
#[default]
ReadNotRequested,
/// A consumer has asked for data; the reader thread leaves its wait and
/// performs a blocking read.
ReadRequested,
/// Bytes produced by the reader thread that have not been fully consumed.
Data(BytesMut),
/// The read failed. The error is handed to the next consumer that observes
/// it, which then moves the state to `Closed`.
Error(std::io::Error),
/// Stdin reported end-of-file (a zero-byte read) or a previously stored
/// error has already been delivered.
Closed,
}
impl GlobalStdin {
fn get() -> &'static GlobalStdin {
static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
STDIN.get_or_init(|| create())
}
}
/// Builds the shared stdin state and starts the background reader thread.
///
/// The spawned thread looks up the same global through `GlobalStdin::get` and
/// then loops:
///
/// 1. wait on `read_requested` until `state` is `ReadRequested`;
/// 2. release the lock and perform one blocking read of up to 1024 bytes;
/// 3. store the outcome in `state` and wake all `read_completed` waiters.
///
/// The thread exits after publishing end-of-file (`Closed`) or a read error
/// (`Error`). Reads failing with `ErrorKind::Interrupted` are retried instead
/// of being published, because that error is transient and would otherwise
/// permanently close stdin for all consumers.
fn create() -> GlobalStdin {
    std::thread::spawn(|| {
        // Called from the new thread, so this waits for the `OnceLock`
        // initialisation that spawned us to finish before returning.
        let global = GlobalStdin::get();
        loop {
            // Sleep until a consumer asks for data, then release the lock so
            // it is never held across the blocking read below.
            let guard = global.state.lock().unwrap();
            let guard = global
                .read_requested
                .wait_while(guard, |s| !matches!(s, StdinState::ReadRequested))
                .unwrap();
            drop(guard);

            let mut bytes = BytesMut::zeroed(1024);
            // `Interrupted` is documented as non-fatal for `Read::read`, so
            // retry rather than reporting it as a terminal stream error.
            let outcome = loop {
                match std::io::stdin().read(&mut bytes) {
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                    other => break other,
                }
            };
            let (new_state, done) = match outcome {
                Ok(0) => (StdinState::Closed, true),
                Ok(nbytes) => {
                    bytes.truncate(nbytes);
                    (StdinState::Data(bytes), false)
                }
                Err(e) => (StdinState::Error(e), true),
            };

            {
                // Check and overwrite under a single lock acquisition so the
                // state that is asserted on is exactly the state replaced.
                // Consumers only ever rewrite `ReadRequested` with itself
                // while a read is in flight, so it should be unchanged here.
                let mut guard = global.state.lock().unwrap();
                debug_assert!(matches!(*guard, StdinState::ReadRequested));
                *guard = new_state;
            }
            // Wake waiters only after the lock has been released.
            global.read_completed.notify_waiters();
            if done {
                break;
            }
        }
    });

    GlobalStdin::default()
}
/// Field-less stdin handle; its `InputStream` and `Pollable` implementations
/// operate entirely on the shared state returned by `GlobalStdin::get`.
struct WasiStdin;
#[async_trait::async_trait]
impl InputStream for WasiStdin {
fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
let g = GlobalStdin::get();
let mut locked = g.state.lock().unwrap();
match mem::replace(&mut *locked, StdinState::ReadRequested) {
StdinState::ReadNotRequested => {
g.read_requested.notify_one();
Ok(Bytes::new())
}
StdinState::ReadRequested => Ok(Bytes::new()),
StdinState::Data(mut data) => {
let size = data.len().min(size);
let bytes = data.split_to(size);
*locked = if data.is_empty() {
StdinState::ReadNotRequested
} else {
StdinState::Data(data)
};
Ok(bytes.freeze())
}
StdinState::Error(e) => {
*locked = StdinState::Closed;
Err(StreamError::LastOperationFailed(e.into()))
}
StdinState::Closed => {
*locked = StdinState::Closed;
Err(StreamError::Closed)
}
}
}
}
#[async_trait::async_trait]
impl Pollable for WasiStdin {
    /// Resolves once the shared stdin state holds something to observe.
    ///
    /// Returns immediately when data, an error, or end-of-file is already
    /// recorded. Otherwise it makes sure a read has been requested and waits
    /// for the reader thread's next `read_completed` notification.
    async fn ready(&mut self) {
        let global = GlobalStdin::get();

        // The synchronous mutex guard is confined to this block so it is
        // released before the `.await` below.
        let completion = {
            let mut state = global.state.lock().unwrap();
            let must_request = match *state {
                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
                StdinState::ReadNotRequested => true,
                StdinState::ReadRequested => false,
            };
            if must_request {
                global.read_requested.notify_one();
                *state = StdinState::ReadRequested;
            }
            // Created while the lock is still held.
            global.read_completed.notified()
        };

        completion.await;
    }
}
/// `AsyncRead` adapter over the shared stdin state.
enum WasiStdinAsyncRead {
/// Not waiting on any notification; the next poll inspects the shared state.
Ready,
/// Waiting on a `read_completed` notification; holds the `Notified` future
/// that is polled at the start of the next `poll_read`.
Waiting(Notified<'static>),
}
impl AsyncRead for WasiStdinAsyncRead {
/// Polls the shared stdin state for data and copies what is available into
/// `buf`.
///
/// Returns `Ready(Ok(()))` after copying buffered bytes, or with nothing
/// written once stdin is closed; returns `Ready(Err(_))` once for a stored
/// read error; returns `Pending` while waiting on the reader thread.
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let g = GlobalStdin::get();
// Loop so that a wake-up which finds no data (for example because another
// consumer already drained it) simply re-registers and waits again.
loop {
// If a previous poll left us waiting, drive that notification first.
if let Some(notified) = self.as_mut().notified_future() {
match notified.poll(cx) {
Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
Poll::Pending => break Poll::Pending,
}
}
assert!(matches!(*self, WasiStdinAsyncRead::Ready));
let mut locked = g.state.lock().unwrap();
// Tentatively mark a read as requested; the arms that return overwrite
// the state with the correct final value.
match mem::replace(&mut *locked, StdinState::ReadRequested) {
// Copy as much buffered data as fits and keep any remainder buffered.
StdinState::Data(mut data) => {
let size = data.len().min(buf.remaining());
let bytes = data.split_to(size);
*locked = if data.is_empty() {
StdinState::ReadNotRequested
} else {
StdinState::Data(data)
};
buf.put_slice(&bytes);
break Poll::Ready(Ok(()));
}
// Deliver the stored error once, then treat stdin as closed.
StdinState::Error(e) => {
*locked = StdinState::Closed;
break Poll::Ready(Err(e));
}
// Closed stays closed; completing without filling `buf` reports EOF.
StdinState::Closed => {
*locked = StdinState::Closed;
break Poll::Ready(Ok(()));
}
// No read outstanding: wake the reader thread, then fall through to wait.
StdinState::ReadNotRequested => {
g.read_requested.notify_one();
}
// A read is already in flight: just fall through to wait.
StdinState::ReadRequested => {}
}
// The `Notified` future is created before the lock is released. The reader
// thread stores its result under this same lock and then calls
// `notify_waiters`, so registering here is intended to ensure that
// completion cannot slip in unnoticed between unlock and registration.
self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
drop(locked);
}
}
}
impl WasiStdinAsyncRead {
    /// Projects the pinned enum onto its pinned `Notified` future.
    ///
    /// Returns `None` in the `Ready` state and `Some` pinned reference to the
    /// stored future in the `Waiting` state.
    fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
        // SAFETY: the mutable reference is used only to reach the payload in
        // place; nothing is moved out of the pinned value.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            WasiStdinAsyncRead::Ready => None,
            WasiStdinAsyncRead::Waiting(pending) => {
                // SAFETY: `pending` is a field of the pinned `self`, so it
                // stays at a stable address for the lifetime of the returned
                // pin (structural pin projection).
                Some(unsafe { Pin::new_unchecked(pending) })
            }
        }
    }
}