sudo-rs 0.2.13

A memory safe implementation of sudo and su.
Documentation
use std::{
    ffi::c_int,
    io,
    mem::size_of,
    os::fd::{AsFd, BorrowedFd},
};

use crate::{
    common::bin_serde::{BinPipe, DeSerialize},
    exec::signal_fmt,
    system::interface::ProcessId,
};

use super::CommandStatus;

type Prefix = u8;
type ParentData = c_int;
type MonitorData = c_int;

const PREFIX_LEN: usize = size_of::<Prefix>();
const PARENT_DATA_LEN: usize = size_of::<ParentData>();
const MONITOR_DATA_LEN: usize = size_of::<MonitorData>();

pub(super) struct BackchannelPair {
    pub(super) parent: ParentBackchannel,
    pub(super) monitor: MonitorBackchannel,
}

impl BackchannelPair {
    pub(super) fn new() -> io::Result<Self> {
        let (sock1, sock2) = BinPipe::pair()?;

        #[cfg(debug_assertions)]
        {
            sock1.set_nonblocking(true)?;
            sock2.set_nonblocking(true)?;
        }

        Ok(Self {
            parent: ParentBackchannel {
                socket: sock1,
                #[cfg(debug_assertions)]
                nonblocking_asserts: false,
            },
            monitor: MonitorBackchannel {
                socket: sock2,
                #[cfg(debug_assertions)]
                nonblocking_asserts: false,
            },
        })
    }
}

pub(super) enum ParentMessage {
    IoError(c_int),
    CommandStatus(CommandStatus),
    CommandPid(ProcessId),
    ShortRead,
}

impl ParentMessage {
    const LEN: usize = PREFIX_LEN + PARENT_DATA_LEN;
    const IO_ERROR: Prefix = 0;
    const CMD_STAT_EXIT: Prefix = 1;
    const CMD_STAT_TERM: Prefix = 2;
    const CMD_STAT_STOP: Prefix = 3;
    const CMD_PID: Prefix = 4;
    const SHORT_READ: Prefix = 5;

    fn from_parts(prefix: Prefix, data: ParentData) -> Self {
        match prefix {
            Self::IO_ERROR => Self::IoError(data),
            Self::CMD_STAT_EXIT => Self::CommandStatus(CommandStatus::Exit(data)),
            Self::CMD_STAT_TERM => Self::CommandStatus(CommandStatus::Term(data)),
            Self::CMD_STAT_STOP => Self::CommandStatus(CommandStatus::Stop(data)),
            Self::CMD_PID => Self::CommandPid(ProcessId::new(data)),
            Self::SHORT_READ => Self::ShortRead,
            _ => unreachable!(),
        }
    }

    fn to_parts(&self) -> (Prefix, ParentData) {
        let prefix = match self {
            ParentMessage::IoError(_) => Self::IO_ERROR,
            ParentMessage::CommandStatus(CommandStatus::Exit(_)) => Self::CMD_STAT_EXIT,
            ParentMessage::CommandStatus(CommandStatus::Term(_)) => Self::CMD_STAT_TERM,
            ParentMessage::CommandStatus(CommandStatus::Stop(_)) => Self::CMD_STAT_STOP,
            ParentMessage::CommandPid(_) => Self::CMD_PID,
            ParentMessage::ShortRead => Self::SHORT_READ,
        };

        let data = match self {
            ParentMessage::IoError(data) => *data,
            ParentMessage::CommandPid(data) => data.inner(),
            ParentMessage::CommandStatus(status) => match status {
                CommandStatus::Exit(data)
                | CommandStatus::Term(data)
                | CommandStatus::Stop(data) => *data,
            },
            ParentMessage::ShortRead => 0,
        };

        (prefix, data)
    }
}

impl TryFrom<io::Error> for ParentMessage {
    type Error = io::Error;

    fn try_from(err: io::Error) -> Result<Self, Self::Error> {
        err.raw_os_error()
            .map(Self::IoError)
            .or_else(|| (err.kind() == io::ErrorKind::UnexpectedEof).then_some(Self::ShortRead))
            .ok_or(err)
    }
}

impl From<CommandStatus> for ParentMessage {
    fn from(status: CommandStatus) -> Self {
        Self::CommandStatus(status)
    }
}

impl DeSerialize for ParentMessage {
    type Bytes = [u8; ParentMessage::LEN];

    fn serialize(&self) -> Self::Bytes {
        let mut buf = [0; ParentMessage::LEN];

        let (prefix_buf, data_buf) = buf.split_at_mut(PREFIX_LEN);
        let (prefix, data) = self.to_parts();

        prefix_buf.copy_from_slice(&prefix.to_ne_bytes());
        data_buf.copy_from_slice(&data.to_ne_bytes());
        buf
    }

    fn deserialize(buf: Self::Bytes) -> Self {
        let (prefix_buf, data_buf) = buf.split_at(PREFIX_LEN);

        let prefix = Prefix::from_ne_bytes(prefix_buf.try_into().unwrap());
        let data = MonitorData::from_ne_bytes(data_buf.try_into().unwrap());

        ParentMessage::from_parts(prefix, data)
    }
}

/// A socket use for communication between the monitor and the parent process.
pub(super) struct ParentBackchannel {
    socket: BinPipe<ParentMessage, MonitorMessage>,
    #[cfg(debug_assertions)]
    nonblocking_asserts: bool,
}

impl ParentBackchannel {
    /// Send a [`MonitorMessage`].
    ///
    /// Calling this method will block until the socket is ready for writing.
    #[track_caller]
    pub(super) fn send(&mut self, event: &MonitorMessage) -> io::Result<()> {
        self.socket.write(event).inspect_err(|_err| {
            #[cfg(debug_assertions)]
            if self.nonblocking_asserts {
                assert_ne!(_err.kind(), io::ErrorKind::WouldBlock);
            }
        })
    }

    /// Receive a [`ParentMessage`].
    ///
    /// Calling this method will block until the socket is ready for reading.
    #[track_caller]
    pub(super) fn recv(&mut self) -> io::Result<ParentMessage> {
        let msg = self.socket.read().inspect_err(|_err| {
            #[cfg(debug_assertions)]
            if self.nonblocking_asserts {
                assert_ne!(_err.kind(), io::ErrorKind::WouldBlock);
            }
        })?;
        Ok(msg)
    }

    pub(super) fn set_nonblocking_asserts(&mut self, _doit: bool) {
        #[cfg(debug_assertions)]
        {
            self.nonblocking_asserts = _doit;
        }
    }
}

impl AsFd for ParentBackchannel {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.socket.as_fd()
    }
}

/// Different messages exchanged between the monitor and the parent process using a [`ParentBackchannel`].
#[derive(PartialEq, Eq)]
pub(super) enum MonitorMessage {
    Edge,
    Signal(c_int),
}

impl MonitorMessage {
    const LEN: usize = PREFIX_LEN + MONITOR_DATA_LEN;
    const EDGE_CMD: Prefix = 0;
    const SIGNAL: Prefix = 1;

    fn from_parts(prefix: Prefix, data: MonitorData) -> Self {
        match prefix {
            Self::EDGE_CMD => Self::Edge,
            Self::SIGNAL => Self::Signal(data),
            _ => unreachable!(),
        }
    }

    fn to_parts(&self) -> (Prefix, MonitorData) {
        let prefix = match self {
            MonitorMessage::Edge => Self::EDGE_CMD,
            MonitorMessage::Signal(_) => Self::SIGNAL,
        };

        let data = match self {
            MonitorMessage::Edge => 0,
            MonitorMessage::Signal(data) => *data,
        };

        (prefix, data)
    }
}

impl std::fmt::Debug for MonitorMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Edge => "Edge".fmt(f),
            &Self::Signal(signal) => write!(f, "Signal({})", signal_fmt(signal)),
        }
    }
}

impl DeSerialize for MonitorMessage {
    type Bytes = [u8; MonitorMessage::LEN];

    fn serialize(&self) -> Self::Bytes {
        let mut buf = [0; MonitorMessage::LEN];

        let (prefix_buf, data_buf) = buf.split_at_mut(PREFIX_LEN);
        let (prefix, data) = self.to_parts();

        prefix_buf.copy_from_slice(&prefix.to_ne_bytes());
        data_buf.copy_from_slice(&data.to_ne_bytes());
        buf
    }

    fn deserialize(bytes: Self::Bytes) -> Self {
        let (prefix_buf, data_buf) = bytes.split_at(PREFIX_LEN);

        let prefix = Prefix::from_ne_bytes(prefix_buf.try_into().unwrap());
        let data = MonitorData::from_ne_bytes(data_buf.try_into().unwrap());

        MonitorMessage::from_parts(prefix, data)
    }
}

/// A socket use for communication between the monitor and the parent process.
pub(super) struct MonitorBackchannel {
    socket: BinPipe<MonitorMessage, ParentMessage>,
    #[cfg(debug_assertions)]
    nonblocking_asserts: bool,
}

impl MonitorBackchannel {
    /// Send a [`ParentMessage`].
    ///
    /// Calling this method will block until the socket is ready for writing.
    #[track_caller]
    pub(super) fn send(&mut self, event: &ParentMessage) -> io::Result<()> {
        self.socket.write(event).inspect_err(|_err| {
            #[cfg(debug_assertions)]
            if self.nonblocking_asserts {
                assert_ne!(_err.kind(), io::ErrorKind::WouldBlock);
            }
        })
    }

    /// Receive a [`MonitorMessage`].
    ///
    /// Calling this method will block until the socket is ready for reading.
    #[track_caller]
    pub(super) fn recv(&mut self) -> io::Result<MonitorMessage> {
        let msg = self.socket.read().inspect_err(|_err| {
            #[cfg(debug_assertions)]
            if self.nonblocking_asserts {
                assert_ne!(_err.kind(), io::ErrorKind::WouldBlock);
            }
        })?;
        Ok(msg)
    }

    pub(super) fn set_nonblocking_assertions(&mut self, _doit: bool) {
        #[cfg(debug_assertions)]
        {
            self.nonblocking_asserts = _doit;
        }
    }
}

impl AsFd for MonitorBackchannel {
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.socket.as_fd()
    }
}