use std::fs::File;
use std::os::fd::{AsRawFd, RawFd};
use ipc_channel::ipc::{self, IpcError, IpcReceiver, IpcSender};
use nix::errno::errno;
use nix::unistd::{close, dup2, fork, ForkResult};
use serde::{Deserialize, Serialize};
use crate::shm::create_shm;
use crate::{bash, Error};
pub fn redirect_output(fd: RawFd) -> crate::Result<()> {
dup2(fd, 1)?;
dup2(fd, 2)?;
close(fd)?;
Ok(())
}
/// Redirect stdout and stderr to `/dev/null`.
///
/// # Errors
///
/// Returns an error if `/dev/null` can't be opened or the redirection fails.
pub fn suppress_output() -> crate::Result<()> {
    let devnull = File::options().write(true).open("/dev/null")?;
    redirect_output(devnull.as_raw_fd())?;
    // `redirect_output` has taken ownership of the descriptor (closing it when appropriate), so
    // dropping the `File` would close the same descriptor number a second time. If redirection
    // failed, the early return above drops the `File` normally instead.
    std::mem::forget(devnull);
    Ok(())
}
struct SharedSemaphore {
sem: *mut libc::sem_t,
size: u32,
}
impl SharedSemaphore {
fn new(size: usize) -> crate::Result<Self> {
let ptr = create_shm("scallop-pool-sem", std::mem::size_of::<libc::sem_t>())?;
let sem = ptr as *mut libc::sem_t;
let size: u32 = size
.try_into()
.map_err(|_| Error::Base(format!("pool too large: {size}")))?;
if unsafe { libc::sem_init(sem, 1, size) } == 0 {
Ok(Self { sem, size })
} else {
let err = errno();
Err(Error::Base(format!("sem_init() failed: {err}")))
}
}
fn acquire(&mut self) -> crate::Result<()> {
if unsafe { libc::sem_wait(self.sem) } == 0 {
Ok(())
} else {
let err = errno();
Err(Error::Base(format!("sem_wait() failed: {err}")))
}
}
fn release(&mut self) -> crate::Result<()> {
if unsafe { libc::sem_post(self.sem) } == 0 {
Ok(())
} else {
let err = errno();
Err(Error::Base(format!("sem_post() failed: {err}")))
}
}
fn wait(&mut self) -> crate::Result<()> {
for _ in 0..self.size {
self.acquire()?;
}
Ok(())
}
}
impl Drop for SharedSemaphore {
fn drop(&mut self) {
unsafe { libc::sem_destroy(self.sem) };
}
}
pub struct PoolIter<T: Serialize + for<'a> Deserialize<'a>> {
rx: IpcReceiver<T>,
}
impl<T: Serialize + for<'a> Deserialize<'a>> PoolIter<T> {
pub fn new<O, I, F>(size: usize, iter: I, func: F, suppress: bool) -> crate::Result<Self>
where
I: Iterator<Item = O>,
F: FnOnce(O) -> T,
{
unsafe { bash::set_sigchld_handler() };
let mut sem = SharedSemaphore::new(size)?;
let (tx, rx): (IpcSender<T>, IpcReceiver<T>) =
ipc::channel().map_err(|e| Error::Base(format!("failed creating IPC channel: {e}")))?;
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => Ok(()),
Ok(ForkResult::Child) => {
if suppress {
suppress_output()?;
}
for obj in iter {
sem.acquire()?;
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => (),
Ok(ForkResult::Child) => {
let r = func(obj);
tx.send(r).map_err(|e| {
Error::Base(format!("process pool sending failed: {e}"))
})?;
sem.release()?;
unsafe { libc::_exit(0) };
}
Err(e) => panic!("process pool fork failed: {e}"),
}
}
unsafe { libc::_exit(0) };
}
Err(e) => Err(Error::Base(format!("starting process pool failed: {e}"))),
}?;
Ok(Self { rx })
}
}
impl<T: Serialize + for<'a> Deserialize<'a>> Iterator for PoolIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
match self.rx.recv() {
Ok(r) => Some(r),
Err(IpcError::Disconnected) => None,
Err(e) => panic!("process pool receiver failed: {e}"),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
enum Msg<T> {
Val(T),
Stop,
}
pub struct PoolSendIter<I, O>
where
I: Serialize + for<'a> Deserialize<'a>,
O: Serialize + for<'a> Deserialize<'a>,
{
input_tx: IpcSender<Msg<I>>,
output_rx: IpcReceiver<Msg<O>>,
}
impl<I, O> PoolSendIter<I, O>
where
I: Serialize + for<'a> Deserialize<'a>,
O: Serialize + for<'a> Deserialize<'a>,
{
pub fn new<F>(size: usize, func: F, suppress: bool) -> crate::Result<Self>
where
F: FnOnce(I) -> O,
{
unsafe { bash::set_sigchld_handler() };
let mut sem = SharedSemaphore::new(size)?;
let (input_tx, input_rx): (IpcSender<Msg<I>>, IpcReceiver<Msg<I>>) = ipc::channel()
.map_err(|e| Error::Base(format!("failed creating input channel: {e}")))?;
let (output_tx, output_rx): (IpcSender<Msg<O>>, IpcReceiver<Msg<O>>) = ipc::channel()
.map_err(|e| Error::Base(format!("failed creating output channel: {e}")))?;
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => Ok(()),
Ok(ForkResult::Child) => {
if suppress {
suppress_output()?;
}
while let Ok(Msg::Val(obj)) = input_rx.recv() {
sem.acquire()?;
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => (),
Ok(ForkResult::Child) => {
let r = func(obj);
output_tx.send(Msg::Val(r)).map_err(|e| {
Error::Base(format!("process pool failed send: {e}"))
})?;
sem.release()?;
unsafe { libc::_exit(0) };
}
Err(e) => panic!("process pool fork failed: {e}"),
}
}
sem.wait()?;
output_tx
.send(Msg::Stop)
.map_err(|e| Error::Base(format!("process pool failed stop: {e}")))?;
unsafe { libc::_exit(0) }
}
Err(e) => Err(Error::Base(format!("process pool failed start: {e}"))),
}?;
Ok(Self { input_tx, output_rx })
}
pub fn iter<V: Iterator<Item = I>>(&self, vals: V) -> crate::Result<PoolReceiveIter<O>> {
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => Ok(()),
Ok(ForkResult::Child) => {
for val in vals {
self.input_tx
.send(Msg::Val(val))
.map_err(|e| Error::Base(format!("failed queuing value: {e}")))?;
}
self.input_tx
.send(Msg::Stop)
.map_err(|e| Error::Base(format!("failed stopping workers: {e}")))?;
unsafe { libc::_exit(0) };
}
Err(e) => Err(Error::Base(format!("failed starting queuing process: {e}"))),
}?;
Ok(PoolReceiveIter { rx: &self.output_rx })
}
}
impl<I, O> Drop for PoolSendIter<I, O>
where
I: Serialize + for<'a> Deserialize<'a>,
O: Serialize + for<'a> Deserialize<'a>,
{
fn drop(&mut self) {
self.input_tx.send(Msg::Stop).ok();
}
}
/// Borrowing iterator over the outputs of a [`PoolSendIter`], ending when the pool sends `Stop`.
pub struct PoolReceiveIter<'p, T>
where
    T: Serialize + for<'a> Deserialize<'a>,
{
    rx: &'p IpcReceiver<Msg<T>>,
}

impl<T> Iterator for PoolReceiveIter<'_, T>
where
    T: Serialize + for<'a> Deserialize<'a>,
{
    type Item = T;

    /// Block for the next worker output. Returns `None` when the `Stop` message arrives.
    ///
    /// # Panics
    ///
    /// Panics if receiving from the output channel fails.
    fn next(&mut self) -> Option<Self::Item> {
        let msg = self
            .rx
            .recv()
            .unwrap_or_else(|e| panic!("output receiver failed: {e}"));
        match msg {
            Msg::Val(value) => Some(value),
            Msg::Stop => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::source;
    use crate::variables::optional;

    use super::*;

    /// Variables set inside worker processes must not leak into the parent, and every input
    /// must come back exactly once.
    #[test]
    fn env_leaking() {
        assert!(optional("VAR").is_none());
        let vals: Vec<u64> = (0..16).collect();
        let func = |i: u64| {
            source::string(format!("VAR={i}")).unwrap();
            assert_eq!(optional("VAR").unwrap(), i.to_string());
            i
        };
        // results arrive in completion order, so sort before comparing with the inputs
        let mut results: Vec<u64> = PoolIter::new(2, vals.iter().copied(), func, false)
            .unwrap()
            .collect();
        results.sort_unstable();
        assert_eq!(results, vals);
        assert!(optional("VAR").is_none());
    }
}