use crate::channel::{channel, Receiver, Sender};
use crate::process::message::Message;
use nix::unistd::Pid;
use std::os::unix::prelude::{AsRawFd, RawFd};
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
#[error("received unexpected message: {received:?}, expected: {expected:?}")]
UnexpectedMessage {
expected: Message,
received: Message,
},
#[error("failed to receive. {msg:?}. {source:?}")]
ReceiveError {
msg: String,
#[source]
source: crate::channel::ChannelError,
},
#[error(transparent)]
BaseChannelError(#[from] crate::channel::ChannelError),
#[error("missing fds from seccomp request")]
MissingSeccompFds,
#[error("exec process failed with error {0}")]
ExecError(String),
}
pub fn main_channel() -> Result<(MainSender, MainReceiver), ChannelError> {
let (sender, receiver) = channel::<Message>()?;
Ok((MainSender { sender }, MainReceiver { receiver }))
}
#[derive(Clone)]
pub struct MainSender {
sender: Sender<Message>,
}
impl MainSender {
pub fn identifier_mapping_request(&mut self) -> Result<(), ChannelError> {
tracing::debug!("send identifier mapping request");
self.sender.send(Message::WriteMapping)?;
Ok(())
}
pub fn seccomp_notify_request(&mut self, fd: RawFd) -> Result<(), ChannelError> {
self.sender
.send_fds(Message::SeccompNotify, &[fd.as_raw_fd()])?;
Ok(())
}
pub fn intermediate_ready(&mut self, pid: Pid) -> Result<(), ChannelError> {
tracing::debug!("sending init pid ({:?})", pid);
self.sender.send(Message::IntermediateReady(pid.as_raw()))?;
Ok(())
}
pub fn init_ready(&mut self) -> Result<(), ChannelError> {
self.sender.send(Message::InitReady)?;
Ok(())
}
pub fn exec_failed(&mut self, err: String) -> Result<(), ChannelError> {
self.sender.send(Message::ExecFailed(err))?;
Ok(())
}
pub fn close(&self) -> Result<(), ChannelError> {
self.sender.close()?;
Ok(())
}
}
#[derive(Clone)]
pub struct MainReceiver {
receiver: Receiver<Message>,
}
impl MainReceiver {
pub fn wait_for_intermediate_ready(&mut self) -> Result<Pid, ChannelError> {
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for intermediate process".to_string(),
source: err,
})?;
match msg {
Message::IntermediateReady(pid) => Ok(Pid::from_raw(pid)),
Message::ExecFailed(err) => Err(ChannelError::ExecError(err)),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::IntermediateReady(0),
received: msg,
}),
}
}
pub fn wait_for_mapping_request(&mut self) -> Result<(), ChannelError> {
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for mapping request".to_string(),
source: err,
})?;
match msg {
Message::WriteMapping => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::WriteMapping,
received: msg,
}),
}
}
pub fn wait_for_seccomp_request(&mut self) -> Result<i32, ChannelError> {
let (msg, fds) = self.receiver.recv_with_fds::<[RawFd; 1]>().map_err(|err| {
ChannelError::ReceiveError {
msg: "waiting for seccomp request".to_string(),
source: err,
}
})?;
match msg {
Message::SeccompNotify => {
let fd = match fds {
Some(fds) => {
if fds.is_empty() {
Err(ChannelError::MissingSeccompFds)
} else {
Ok(fds[0])
}
}
None => Err(ChannelError::MissingSeccompFds),
}?;
Ok(fd)
}
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::SeccompNotify,
received: msg,
}),
}
}
pub fn wait_for_init_ready(&mut self) -> Result<(), ChannelError> {
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for init ready".to_string(),
source: err,
})?;
match msg {
Message::InitReady => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::InitReady,
received: msg,
}),
}
}
pub fn close(&self) -> Result<(), ChannelError> {
self.receiver.close()?;
Ok(())
}
}
pub fn intermediate_channel() -> Result<(IntermediateSender, IntermediateReceiver), ChannelError> {
let (sender, receiver) = channel::<Message>()?;
Ok((
IntermediateSender { sender },
IntermediateReceiver { receiver },
))
}
#[derive(Clone)]
pub struct IntermediateSender {
sender: Sender<Message>,
}
impl IntermediateSender {
pub fn mapping_written(&mut self) -> Result<(), ChannelError> {
tracing::debug!("identifier mapping written");
self.sender.send(Message::MappingWritten)?;
Ok(())
}
pub fn close(&self) -> Result<(), ChannelError> {
self.sender.close()?;
Ok(())
}
}
#[derive(Clone)]
pub struct IntermediateReceiver {
receiver: Receiver<Message>,
}
impl IntermediateReceiver {
pub fn wait_for_mapping_ack(&mut self) -> Result<(), ChannelError> {
tracing::debug!("waiting for mapping ack");
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for mapping ack".to_string(),
source: err,
})?;
match msg {
Message::MappingWritten => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::MappingWritten,
received: msg,
}),
}
}
pub fn close(&self) -> Result<(), ChannelError> {
self.receiver.close()?;
Ok(())
}
}
pub fn init_channel() -> Result<(InitSender, InitReceiver), ChannelError> {
let (sender, receiver) = channel::<Message>()?;
Ok((InitSender { sender }, InitReceiver { receiver }))
}
#[derive(Clone)]
pub struct InitSender {
sender: Sender<Message>,
}
impl InitSender {
pub fn seccomp_notify_done(&mut self) -> Result<(), ChannelError> {
self.sender.send(Message::SeccompNotifyDone)?;
Ok(())
}
pub fn close(&self) -> Result<(), ChannelError> {
self.sender.close()?;
Ok(())
}
}
#[derive(Clone)]
pub struct InitReceiver {
receiver: Receiver<Message>,
}
impl InitReceiver {
pub fn wait_for_seccomp_request_done(&mut self) -> Result<(), ChannelError> {
let msg = self
.receiver
.recv()
.map_err(|err| ChannelError::ReceiveError {
msg: "waiting for seccomp request".to_string(),
source: err,
})?;
match msg {
Message::SeccompNotifyDone => Ok(()),
msg => Err(ChannelError::UnexpectedMessage {
expected: Message::SeccompNotifyDone,
received: msg,
}),
}
}
pub fn close(&self) -> Result<(), ChannelError> {
self.receiver.close()?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::{Context, Result};
use nix::sys::wait;
use nix::unistd;
use serial_test::serial;
#[test]
#[serial]
fn test_channel_intermadiate_ready() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
wait::waitpid(child, None)?;
let pid = receiver
.wait_for_intermediate_ready()
.with_context(|| "Failed to wait for intermadiate ready")?;
receiver.close()?;
assert_eq!(pid, child);
}
unistd::ForkResult::Child => {
let pid = unistd::getpid();
sender.intermediate_ready(pid)?;
sender.close()?;
std::process::exit(0);
}
};
Ok(())
}
#[test]
#[serial]
fn test_channel_id_mapping_request() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
wait::waitpid(child, None)?;
receiver.wait_for_mapping_request()?;
receiver.close()?;
}
unistd::ForkResult::Child => {
sender
.identifier_mapping_request()
.with_context(|| "Failed to send mapping written")?;
sender.close()?;
std::process::exit(0);
}
};
Ok(())
}
#[test]
#[serial]
fn test_channel_id_mapping_ack() -> Result<()> {
let (sender, receiver) = &mut intermediate_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
wait::waitpid(child, None)?;
receiver.wait_for_mapping_ack()?;
}
unistd::ForkResult::Child => {
sender
.mapping_written()
.with_context(|| "Failed to send mapping written")?;
std::process::exit(0);
}
};
Ok(())
}
#[test]
#[serial]
fn test_channel_init_ready() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
wait::waitpid(child, None)?;
receiver.wait_for_init_ready()?;
receiver.close()?;
}
unistd::ForkResult::Child => {
sender
.init_ready()
.with_context(|| "Failed to send init ready")?;
sender.close()?;
std::process::exit(0);
}
};
Ok(())
}
#[test]
#[serial]
fn test_channel_main_graceful_exit() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
sender.close().context("failed to close sender")?;
let ret = receiver.wait_for_intermediate_ready();
assert!(ret.is_err());
wait::waitpid(child, None)?;
}
unistd::ForkResult::Child => {
receiver.close()?;
std::process::exit(0);
}
};
Ok(())
}
#[test]
#[serial]
fn test_channel_intermediate_graceful_exit() -> Result<()> {
let (sender, receiver) = &mut main_channel()?;
match unsafe { unistd::fork()? } {
unistd::ForkResult::Parent { child } => {
sender.close().context("failed to close sender")?;
let ret = receiver.wait_for_init_ready();
assert!(ret.is_err());
wait::waitpid(child, None)?;
}
unistd::ForkResult::Child => {
receiver.close()?;
std::process::exit(0);
}
};
Ok(())
}
}