extern crate nix;
use super::{ShellError};
use std::path::PathBuf;
use std::os::unix::io::RawFd;
use std::time::{Instant, Duration};
use nix::unistd;
/// A named FIFO on disk together with the raw descriptor it was opened with.
///
/// Deriving `Clone` copies the descriptor *number*, not the open file: every clone refers to
/// the same descriptor, so closing one clone invalidates all of them.
#[derive(std::fmt::Debug, Clone)]
pub(crate) struct Pipe {
    /// Filesystem location of the FIFO.
    pub path: PathBuf,
    /// Raw file descriptor of the opened FIFO.
    pub fd: RawFd,
}
impl Pipe {
pub fn open(path: &PathBuf) -> Result<Pipe, ShellError> {
if let Err(err) = unistd::mkfifo(path.as_path(), nix::sys::stat::Mode::S_IRWXU | nix::sys::stat::Mode::S_IRWXG | nix::sys::stat::Mode::S_IRWXO) {
match err {
nix::Error::Sys(errno) => return Err(ShellError::PipeError(errno)),
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
match nix::fcntl::open(path.as_path(), nix::fcntl::OFlag::O_RDWR, nix::sys::stat::Mode::S_IRWXU | nix::sys::stat::Mode::S_IRWXG | nix::sys::stat::Mode::S_IRWXO) {
Ok(fd) => {
Ok(Pipe {
path: path.clone(),
fd: fd
})
},
Err(err) => {
match err {
nix::Error::Sys(errno) => Err(ShellError::PipeError(errno)),
_ => Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
}
}
pub fn close(&self) -> Result<(), ShellError> {
if let Err(err) = unistd::close(self.fd) {
match err {
nix::Error::Sys(errno) => return Err(ShellError::PipeError(errno)),
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
};
let _ = unistd::unlink(self.path.as_path());
Ok(())
}
pub fn read(&self, timeout: u64, read_all: bool) -> Result<Option<String>, ShellError> {
let mut poll_fds: [nix::poll::PollFd; 1] = [nix::poll::PollFd::new(self.fd, nix::poll::PollFlags::POLLIN | nix::poll::PollFlags::POLLRDBAND | nix::poll::PollFlags::POLLHUP)];
let mut data_out: String = String::new();
let mut data_size: usize = 0;
let timeout: Duration = Duration::from_millis(timeout);
let time: Instant = Instant::now();
while time.elapsed() < timeout {
match nix::poll::poll(&mut poll_fds, 50) {
Ok(ret) => {
if ret > 0 && poll_fds[0].revents().is_some() { let event: nix::poll::PollFlags = poll_fds[0].revents().unwrap();
if event.intersects(nix::poll::PollFlags::POLLIN) || event.intersects(nix::poll::PollFlags::POLLRDBAND) {
let mut buffer: [u8; 8192] = [0; 8192];
match unistd::read(self.fd, &mut buffer) {
Ok(bytes_read) => {
data_size += bytes_read;
data_out.push_str(match std::str::from_utf8(&buffer[0..bytes_read]) {
Ok(s) => s,
Err(_) => {
return Err(ShellError::InvalidData)
}
});
if ! read_all {
break;
}
},
Err(err) => {
match err {
nix::Error::Sys(errno) => {
match errno {
nix::errno::Errno::EAGAIN => { if data_size == 0 {
continue; } else {
break; }
},
_ => return Err(ShellError::PipeError(errno)) }
},
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
};
} else if event.intersects(nix::poll::PollFlags::POLLERR) { return Err(ShellError::PipeError(nix::errno::Errno::EPIPE))
} else if event.intersects(nix::poll::PollFlags::POLLHUP) { if data_size == 0 {
continue;
} else {
break;
}
}
} else if ret == 0 {
if data_size == 0 {
continue;
} else {
break;
}
}
},
Err(err) => { match err {
nix::Error::Sys(errno) => {
match errno {
nix::errno::Errno::EAGAIN => { if data_size == 0 {
continue; } else {
break; }
},
_ => return Err(ShellError::PipeError(errno)) }
},
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
}
}
match data_size {
0 => Ok(None),
_ => Ok(Some(data_out))
}
}
pub fn write(&self, data: String, timeout: u64) -> Result<(), ShellError> {
let mut poll_fds: [nix::poll::PollFd; 1] = [nix::poll::PollFd::new(self.fd, nix::poll::PollFlags::POLLOUT)];
let timeout: Duration = Duration::from_millis(timeout);
let time: Instant = Instant::now();
let data_out = data.as_bytes();
let total_bytes_amount: usize = data_out.len();
let mut bytes_written: usize = 0;
while bytes_written < total_bytes_amount {
match nix::poll::poll(&mut poll_fds, 50) {
Ok(_) => {
if let Some(revents) = poll_fds[0].revents() {
if revents.intersects(nix::poll::PollFlags::POLLOUT) {
let bytes_out = if total_bytes_amount - bytes_written > 8192 {
8192
} else {
total_bytes_amount - bytes_written
};
match unistd::write(self.fd, &data_out[bytes_written..(bytes_written + bytes_out)]) {
Ok(bytes) => {
bytes_written += bytes; },
Err(err) => {
match err {
nix::Error::Sys(errno) => return Err(ShellError::PipeError(errno)),
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
}
}
}
},
Err(err) => {
match err {
nix::Error::Sys(errno) => return Err(ShellError::PipeError(errno)),
_ => return Err(ShellError::PipeError(nix::errno::Errno::UnknownErrno))
}
}
};
if bytes_written == 0 && time.elapsed() >= timeout {
return Err(ShellError::IoTimeout);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Creating a FIFO in a fresh directory succeeds and it can be closed again.
    #[test]
    fn test_pipe_open_close() {
        let tmpdir: tempfile::TempDir = create_tmp_dir();
        let pipe_path: PathBuf = tmpdir.path().join("test.fifo");
        let pipe: Result<Pipe, ShellError> = Pipe::open(&pipe_path);
        assert!(
            pipe.is_ok(),
            "Pipe ({}) should be OK, but is {:?}",
            pipe_path.display(),
            pipe
        );
        let pipe: Pipe = pipe.unwrap();
        assert_eq!(pipe.path, pipe_path);
        assert!(pipe.fd > 0);
        assert!(pipe.close().is_ok());
    }

    /// Round trip: main writes, a second thread reads and answers on the same FIFO.
    #[test]
    fn test_pipe_io() {
        let tmpdir: tempfile::TempDir = create_tmp_dir();
        let pipe_path: PathBuf = tmpdir.path().join("stdout.fifo");
        let pipe: Result<Pipe, ShellError> = Pipe::open(&pipe_path);
        assert!(
            pipe.is_ok(),
            "Pipe ({}) should be OK, but is {:?}",
            pipe_path.display(),
            pipe
        );
        let pipe: Pipe = pipe.unwrap();
        let pipe_thread: Pipe = pipe.clone();
        let join_hnd: thread::JoinHandle<()> = thread::spawn(move || {
            let input: String = pipe_thread.read(1000, true).unwrap().unwrap();
            assert_eq!(input, String::from("HELLO\n"));
            thread::sleep(Duration::from_millis(100));
            assert!(pipe_thread.write(String::from("HI THERE\n"), 1000).is_ok());
        });
        assert!(pipe.write(String::from("HELLO\n"), 1000).is_ok(), "Write timeout");
        thread::sleep(Duration::from_millis(100));
        let read: Result<Option<String>, ShellError> = pipe.read(1000, true);
        assert!(read.is_ok(), "Read should be Ok, but is {:?}", read);
        let read: Option<String> = read.unwrap();
        assert_eq!(read.unwrap(), String::from("HI THERE\n"));
        assert!(join_hnd.join().is_ok());
        assert!(pipe.close().is_ok());
    }

    /// `read_all` collects multiple chunks; single reads return at most one 8192-byte chunk.
    #[test]
    fn test_pipe_read_all() {
        let tmpdir: tempfile::TempDir = create_tmp_dir();
        let pipe_path: PathBuf = tmpdir.path().join("stdout.fifo");
        let pipe: Result<Pipe, ShellError> = Pipe::open(&pipe_path);
        assert!(
            pipe.is_ok(),
            "Pipe ({}) should be OK, but is {:?}",
            pipe_path.display(),
            pipe
        );
        let pipe: Pipe = pipe.unwrap();
        let pipe_thread: Pipe = pipe.clone();
        let join_hnd: thread::JoinHandle<()> = thread::spawn(move || {
            let data: String = "c".repeat(10240);
            assert!(pipe_thread.write(data.clone(), 1000).is_ok());
            thread::sleep(Duration::from_millis(500));
            assert!(pipe_thread.write(data, 1000).is_ok());
        });
        assert_eq!(pipe.read(500, true).unwrap().unwrap().len(), 10240);
        thread::sleep(Duration::from_millis(500));
        assert_eq!(pipe.read(500, false).unwrap().unwrap().len(), 8192);
        assert_eq!(pipe.read(500, false).unwrap().unwrap().len(), 2048);
        assert!(join_hnd.join().is_ok());
        assert!(pipe.close().is_ok());
    }

    /// `open` fails on an existing path and `close` fails on an invalid descriptor.
    #[test]
    fn test_pipe_open_close_error() {
        let tmpdir: tempfile::TempDir = create_tmp_dir();
        // mkfifo must fail with EEXIST on a path that already exists. The temp directory is
        // guaranteed to exist, unlike a device node such as /dev/tty1, which may be missing in
        // containers (where a privileged run would then create a FIFO in /dev).
        let existing_path: PathBuf = tmpdir.path().to_path_buf();
        let pipe: Result<Pipe, ShellError> = Pipe::open(&existing_path);
        assert!(pipe.is_err());
        // -1 is never a valid descriptor, so close(2) fails with EBADF. A small positive value
        // could belong to another test running concurrently and would be closed under it.
        let pipe: Pipe = Pipe {
            fd: -1,
            path: tmpdir.path().join("stdout.fifo"),
        };
        assert!(pipe.close().is_err());
    }

    /// Reading from a FIFO nobody writes to yields `None` after the timeout.
    #[test]
    fn test_pipe_io_error() {
        let tmpdir: tempfile::TempDir = create_tmp_dir();
        let pipe_path: PathBuf = tmpdir.path().join("stdout.fifo");
        let pipe: Result<Pipe, ShellError> = Pipe::open(&pipe_path);
        assert!(
            pipe.is_ok(),
            "Pipe ({}) should be OK, but is {:?}",
            pipe_path.display(),
            pipe
        );
        let pipe: Pipe = pipe.unwrap();
        assert!(pipe.read(1000, true).unwrap().is_none(), "Read should be None");
        assert!(pipe.close().is_ok());
    }

    /// Fresh temporary directory, removed when the returned guard is dropped.
    fn create_tmp_dir() -> tempfile::TempDir {
        tempfile::TempDir::new().unwrap()
    }
}