use std::fmt;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWriteExt};
use tokio::sync::Mutex as AsyncMutex;
use tokio_stream::{Stream, StreamExt};
/// A caller-supplied async byte source, boxed and pinned so it can be polled in place.
type BoxedReader = Pin<Box<dyn AsyncRead + Send>>;
/// A caller-supplied async stream of lines, boxed and pinned so it can be polled in place.
type BoxedLines = Pin<Box<dyn Stream<Item = String> + Send>>;
/// Reader slot shared between clones of a `Stdin`; emptied (`None`) once consumed.
type SharedReader = Arc<AsyncMutex<Option<BoxedReader>>>;
/// Line-stream slot shared between clones of a `Stdin`; emptied (`None`) once consumed.
type SharedLines = Arc<AsyncMutex<Option<BoxedLines>>>;
/// Describes what a spawned child process should receive on its standard input.
///
/// Cloning copies in-memory bytes and file paths. Reader and stream sources are
/// shared between clones behind an `Arc`, so only the first write from any clone
/// actually consumes them.
#[derive(Clone)]
pub struct Stdin(Source);
/// Internal representation of a [`Stdin`] input source.
#[derive(Clone)]
enum Source {
    /// No input; the child's stdin is connected to the null device.
    Empty,
    /// Bytes held in memory and written verbatim; reusable across writes.
    Bytes(Vec<u8>),
    /// Path of a file whose contents are streamed; opened at write time.
    File(PathBuf),
    /// Shared, one-shot stream of lines; each line is written followed by `\n`.
    Lines(SharedLines),
    /// Shared, one-shot async reader whose bytes are copied to the child.
    Reader(SharedReader),
}
impl Stdin {
    /// Creates a stdin source that provides no input.
    ///
    /// The child's stdin is connected to the null device instead of a pipe.
    pub fn empty() -> Self {
        Stdin(Source::Empty)
    }

    /// Creates a stdin source from text, written to the child verbatim.
    pub fn from_string(text: impl Into<String>) -> Self {
        Stdin(Source::Bytes(text.into().into_bytes()))
    }

    /// Creates a stdin source from raw bytes, written to the child verbatim.
    pub fn from_bytes(bytes: impl Into<Vec<u8>>) -> Self {
        Stdin(Source::Bytes(bytes.into()))
    }

    /// Creates a stdin source that streams the contents of the file at `path`.
    ///
    /// The file is not opened here. It is opened when the input is written, so a
    /// missing or unreadable file surfaces as an error from that write.
    pub fn from_file(path: impl AsRef<Path>) -> Self {
        Stdin(Source::File(path.as_ref().to_path_buf()))
    }

    /// Creates a stdin source from a sequence of lines, each terminated with `\n`.
    ///
    /// The lines are collected eagerly into an in-memory buffer, so the resulting
    /// source can be written any number of times.
    pub fn from_iter_lines<I, S>(lines: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut buf = Vec::new();
        for line in lines {
            buf.extend_from_slice(line.as_ref().as_bytes());
            buf.push(b'\n');
        }
        Stdin(Source::Bytes(buf))
    }

    /// Creates a stdin source that copies everything produced by `reader`.
    ///
    /// The reader is consumed by the first write. Clones of the returned `Stdin`
    /// share the same reader, so any later write, from this value or a clone,
    /// sends no data.
    pub fn from_reader<R>(reader: R) -> Self
    where
        R: AsyncRead + Send + 'static,
    {
        Stdin(Source::Reader(Arc::new(AsyncMutex::new(Some(Box::pin(
            reader,
        ))))))
    }

    /// Creates a stdin source that writes each item of `lines` followed by `\n`.
    ///
    /// Like `from_reader`, the stream is consumed by the first write and shared
    /// between clones; later writes send no data.
    pub fn from_lines<S>(lines: S) -> Self
    where
        S: Stream<Item = String> + Send + 'static,
    {
        Stdin(Source::Lines(Arc::new(AsyncMutex::new(Some(Box::pin(
            lines,
        ))))))
    }

    /// Returns `true` if this source was created with `Stdin::empty`.
    pub(crate) fn is_empty(&self) -> bool {
        matches!(self.0, Source::Empty)
    }

    /// Returns the `Stdio` configuration to use for the child's stdin.
    ///
    /// An empty source maps to `Stdio::null()`. Every other source gets a pipe,
    /// including a reader or stream that has already been consumed; in that case
    /// nothing is written to the pipe.
    pub(crate) fn stdio(&self) -> Stdio {
        if self.is_empty() {
            Stdio::null()
        } else {
            Stdio::piped()
        }
    }

    /// Writes this source's input into the child's stdin pipe.
    ///
    /// This method does not shut down `sink`. Reader and stream sources are taken
    /// out of their shared slot, so a second call (or a call on a clone) writes
    /// nothing for them.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening a file source, reading a reader source,
    /// or writing to `sink`.
    pub(crate) async fn write_to(
        &self,
        sink: &mut tokio::process::ChildStdin,
    ) -> std::io::Result<()> {
        match &self.0 {
            Source::Empty => Ok(()),
            Source::Bytes(bytes) => sink.write_all(bytes).await,
            Source::File(path) => {
                let mut file = tokio::fs::File::open(path).await?;
                tokio::io::copy(&mut file, sink).await.map(|_| ())
            }
            Source::Reader(reader) => {
                // The guard lives until the copy finishes. A clone writing
                // concurrently waits here and then finds the slot empty.
                let mut guard = reader.lock().await;
                match guard.take() {
                    Some(mut r) => tokio::io::copy(&mut r, sink).await.map(|_| ()),
                    None => Ok(()),
                }
            }
            Source::Lines(lines) => {
                let mut guard = lines.lock().await;
                match guard.take() {
                    Some(mut stream) => {
                        while let Some(line) = stream.next().await {
                            // Append the terminator to the owned line so each line
                            // goes out in a single `write_all`, rather than one
                            // write for the text and another for "\n".
                            let mut chunk = line.into_bytes();
                            chunk.push(b'\n');
                            sink.write_all(&chunk).await?;
                        }
                        Ok(())
                    }
                    None => Ok(()),
                }
            }
        }
    }
}
impl fmt::Debug for Stdin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let kind = match &self.0 {
Source::Empty => "Empty",
Source::Bytes(_) => "Bytes",
Source::File(_) => "File",
Source::Reader(_) => "Reader",
Source::Lines(_) => "Lines",
};
f.debug_tuple("Stdin").field(&kind).finish()
}
}
/// Handle for writing incrementally to a running child process's stdin pipe.
pub struct ProcessStdin {
    /// The child's stdin pipe that all writes go to.
    sink: tokio::process::ChildStdin,
}
impl ProcessStdin {
    /// Wraps the stdin pipe of a spawned child.
    pub(crate) fn new(sink: tokio::process::ChildStdin) -> Self {
        Self { sink }
    }

    /// Writes all of `bytes` to the child's stdin.
    ///
    /// Unlike `write_line`, this does not flush afterwards.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the pipe.
    pub async fn write(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.sink.write_all(bytes).await
    }

    /// Writes `line` followed by `\n`, then flushes.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the pipe while writing or flushing.
    pub async fn write_line(&mut self, line: &str) -> std::io::Result<()> {
        // Assemble text and terminator once so the line is sent with a single
        // `write_all` instead of two separate writes to the pipe.
        let mut chunk = Vec::with_capacity(line.len() + 1);
        chunk.extend_from_slice(line.as_bytes());
        chunk.push(b'\n');
        self.sink.write_all(&chunk).await?;
        self.sink.flush().await
    }

    /// Flushes any pending data to the child's stdin.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the pipe.
    pub async fn flush(&mut self) -> std::io::Result<()> {
        self.sink.flush().await
    }

    /// Shuts down the pipe and consumes this handle.
    ///
    /// Consuming `self` drops the pipe handle when this returns. Closing the
    /// write end of a pipe is what lets the reader observe end of input.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported while shutting down the pipe.
    pub async fn finish(mut self) -> std::io::Result<()> {
        self.sink.shutdown().await
    }
}
impl fmt::Debug for ProcessStdin {
    /// Formats as a non-exhaustive `ProcessStdin` struct without exposing the pipe.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ProcessStdin");
        // The pipe handle is intentionally hidden; `..` marks the omission.
        builder.finish_non_exhaustive()
    }
}