use std::io::{self, Read, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};

use crate::communication::CommunicationChannel;
use crate::error::Result;
pub struct StdioRedirection {
stdin_channel: Arc<dyn CommunicationChannel>,
stdout_channel: Arc<dyn CommunicationChannel>,
stderr_channel: Arc<dyn CommunicationChannel>,
stdin_buffer: Mutex<Vec<u8>>,
closed: Mutex<bool>,
}
impl StdioRedirection {
pub fn new(
stdin_channel: Arc<dyn CommunicationChannel>,
stdout_channel: Arc<dyn CommunicationChannel>,
stderr_channel: Arc<dyn CommunicationChannel>,
) -> Self {
Self {
stdin_channel,
stdout_channel,
stderr_channel,
stdin_buffer: Mutex::new(Vec::new()),
closed: Mutex::new(false),
}
}
pub fn write_stdout(&self, data: &[u8]) -> Result<()> {
self.stdout_channel.send_to_guest(data)
}
pub fn write_stderr(&self, data: &[u8]) -> Result<()> {
self.stderr_channel.send_to_guest(data)
}
pub fn read_stdin(&self, buf: &mut [u8]) -> Result<usize> {
if *self.closed.lock().unwrap() {
return Ok(0);
}
let mut stdin_buffer = self.stdin_buffer.lock().unwrap();
if stdin_buffer.is_empty() {
if let Ok(data) = self.stdin_channel.receive_from_guest() {
stdin_buffer.extend_from_slice(&data);
}
}
if !stdin_buffer.is_empty() {
let n = std::cmp::min(buf.len(), stdin_buffer.len());
buf[..n].copy_from_slice(&stdin_buffer[..n]);
stdin_buffer.drain(..n);
Ok(n)
} else {
Ok(0)
}
}
pub fn close(&self) -> Result<()> {
let mut closed = self.closed.lock().unwrap();
*closed = true;
self.stdin_channel.close()?;
self.stdout_channel.close()?;
self.stderr_channel.close()?;
Ok(())
}
}
/// An [`io::Read`](Read) adapter whose reads are served by
/// [`StdioRedirection::read_stdin`].
pub struct StdioInput {
    redirection: Arc<StdioRedirection>,
}

impl StdioInput {
    /// Wraps a shared `redirection` so it can be used wherever a `Read` is expected.
    pub fn new(redirection: Arc<StdioRedirection>) -> Self {
        StdioInput { redirection }
    }
}

impl Read for StdioInput {
    /// Forwards to `read_stdin`.
    ///
    /// A crate error is turned into an `io::Error` of kind `Other` whose message is the
    /// error's `Debug` rendering.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.redirection.read_stdin(buf) {
            Ok(count) => Ok(count),
            Err(err) => Err(io::Error::other(format!("{:?}", err))),
        }
    }
}
/// An [`io::Write`](Write) adapter that forwards bytes to either the stdout or the stderr
/// side of a shared [`StdioRedirection`].
pub struct StdioOutput {
    redirection: Arc<StdioRedirection>,
    /// `true` routes writes through `write_stderr`, `false` through `write_stdout`.
    is_stderr: bool,
}

impl StdioOutput {
    /// Creates a writer whose output goes through [`StdioRedirection::write_stdout`].
    pub fn new_stdout(redirection: Arc<StdioRedirection>) -> Self {
        Self {
            redirection,
            is_stderr: false,
        }
    }

    /// Creates a writer whose output goes through [`StdioRedirection::write_stderr`].
    pub fn new_stderr(redirection: Arc<StdioRedirection>) -> Self {
        Self {
            redirection,
            is_stderr: true,
        }
    }
}

impl Write for StdioOutput {
    /// Hands all of `buf` to the selected stream in one call and reports `buf.len()` bytes
    /// written.
    ///
    /// # Errors
    ///
    /// A channel error becomes an `io::Error` of kind `Other` whose message is the error's
    /// `Debug` rendering.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // An empty write is a no-op under the `Write` contract. Forwarding it would send a
        // zero-length message, which a receiver may interpret as end-of-stream.
        if buf.is_empty() {
            return Ok(0);
        }
        let result = if self.is_stderr {
            self.redirection.write_stderr(buf)
        } else {
            self.redirection.write_stdout(buf)
        };
        result
            .map(|()| buf.len())
            .map_err(|e| io::Error::other(format!("{:?}", e)))
    }

    /// No-op: every `write` is passed to the channel immediately, and this type keeps no
    /// buffer of its own.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Builds [`StdioRedirection`]s from channels supplied by a channel factory.
pub struct StdioFactory {
    channel_factory: Arc<dyn crate::communication::CommunicationFactory>,
}

impl StdioFactory {
    /// Creates a factory that obtains its channels from `channel_factory`.
    pub fn new(channel_factory: Arc<dyn crate::communication::CommunicationFactory>) -> Self {
        StdioFactory { channel_factory }
    }

    /// Requests three channels (stdin, stdout, stderr, in that order) and wraps them in a
    /// shared [`StdioRedirection`].
    ///
    /// # Errors
    ///
    /// Returns the first error from `create_channel`. Channels that would have been
    /// requested after the failing call are not requested.
    pub fn create_redirection(&self) -> Result<Arc<StdioRedirection>> {
        let factory = &self.channel_factory;
        let stdin = factory.create_channel()?;
        let stdout = factory.create_channel()?;
        let stderr = factory.create_channel()?;
        Ok(Arc::new(StdioRedirection::new(stdin, stdout, stderr)))
    }
}