// sudo-rs 0.2.13
//
// A memory safe implementation of sudo and su.
// Documentation
use crate::{cutils::cerr, system::make_zeroed_sigaction};

use super::{SignalNumber, handler::SignalHandlerBehavior};

use std::ffi::c_int;
use std::io;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicBool, Ordering};

/// Newtype wrapper around a raw `libc::sigaction` value.
///
/// `repr(transparent)` gives this type the same layout as `libc::sigaction`; `register` relies
/// on that when it casts a pointer to `SignalAction` into the pointer handed to
/// `libc::sigaction`.
#[repr(transparent)]
pub(super) struct SignalAction {
    // The underlying C structure; `new` fills in its handler, mask and flags.
    raw: libc::sigaction,
}

impl SignalAction {
    /// Build the action description corresponding to `behavior`.
    ///
    /// This only fills in the `sigaction` structure; nothing is installed until `register` is
    /// called.
    ///
    /// # Errors
    ///
    /// Returns the error produced while building the signal mask (`SignalSet::empty` or
    /// `SignalSet::full`).
    pub(super) fn new(behavior: SignalHandlerBehavior) -> io::Result<Self> {
        // For each behavior pick the handler, the mask applied while the handler runs, and any
        // flags needed on top of `SA_RESTART`. The two custom handlers get a full mask so that
        // no other signal is delivered while they execute.
        let (handler, mask, extra_flags) = match behavior {
            SignalHandlerBehavior::Default => (libc::SIG_DFL, SignalSet::empty()?, 0),
            SignalHandlerBehavior::Ignore => (libc::SIG_IGN, SignalSet::empty()?, 0),
            SignalHandlerBehavior::Stream => (
                super::stream::send_siginfo as *const () as libc::sighandler_t,
                SignalSet::full()?,
                // Tells `sigaction` that `sa_sigaction` holds a signal-catching function.
                libc::SA_SIGINFO,
            ),
            SignalHandlerBehavior::StorePending => (
                store_pending as *const () as libc::sighandler_t,
                SignalSet::full()?,
                0,
            ),
        };

        let mut action = Self {
            raw: make_zeroed_sigaction(),
        };
        action.raw.sa_sigaction = handler;
        action.raw.sa_mask = mask.raw;
        // `SA_RESTART` is always set so that interrupted calls are restarted instead of failing
        // when this signal arrives.
        action.raw.sa_flags = libc::SA_RESTART | extra_flags;

        Ok(action)
    }

    /// Install this action for `signal` and return the action that was previously installed.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cerr` when the `sigaction` call fails.
    pub(super) fn register(&self, signal: SignalNumber) -> io::Result<Self> {
        let mut previous = MaybeUninit::<Self>::zeroed();
        // The cast is valid because `SignalAction` is a `repr(transparent)` newtype over
        // `libc::sigaction`.
        let previous_ptr: *mut libc::sigaction = previous.as_mut_ptr().cast();

        // SAFETY: both pointers are valid for the duration of the call: `&self.raw` is a live
        // reference and `previous_ptr` points to writable storage of the right layout.
        let status = unsafe { libc::sigaction(signal, &self.raw, previous_ptr) };
        cerr(status)?;

        // SAFETY: `sigaction` succeeded, so it has written the old action into `previous`.
        Ok(unsafe { previous.assume_init() })
    }
}

/// Per-signal "pending" flags, indexed by signal number.
///
/// `store_pending` sets the flag of the signal it receives; `take_first_pending` clears and
/// reports the lowest-indexed flag that is set.
// NOTE(review): only indices 0..=63 exist, so a signal number of 64 or above cannot be
// recorded here — confirm that no such signal is ever registered with `store_pending`.
static PENDING_SIGNALS: [AtomicBool; 64] = [const { AtomicBool::new(false) }; 64];

/// Signal handler that marks `signal` as pending in `PENDING_SIGNALS`.
///
/// `SignalAction::new` installs this function for `SignalHandlerBehavior::StorePending`; the
/// flag it sets is later consumed by `take_first_pending`.
///
/// The function is declared `extern "C"` because it is invoked through the C function pointer
/// stored in `sigaction`. The table lookup never panics: a signal number that does not fit the
/// table (negative, or not below `PENDING_SIGNALS.len()`) is ignored, since panicking inside a
/// signal handler must be avoided.
extern "C" fn store_pending(signal: SignalNumber) {
    // Checked conversion and checked indexing instead of `signal as usize` plus `[]`.
    let slot = usize::try_from(signal)
        .ok()
        .and_then(|index| PENDING_SIGNALS.get(index));

    if let Some(flag) = slot {
        flag.store(true, Ordering::SeqCst);
    }
}

/// Take the lowest-numbered signal whose flag is set in `PENDING_SIGNALS`.
///
/// The flag of the returned signal is cleared; flags of higher-numbered signals are left
/// untouched, so repeated calls drain the pending signals one at a time. Returns `None` when no
/// flag is set.
pub(crate) fn take_first_pending() -> Option<SignalNumber> {
    // `swap` reads and clears a flag in a single atomic operation; `position` stops at the
    // first flag that was set.
    let index = PENDING_SIGNALS
        .iter()
        .position(|flag| flag.swap(false, Ordering::SeqCst))?;

    // `index` is a position inside the 64-entry table.
    Some(SignalNumber::try_from(index).unwrap())
}

/// A signal set that can be used to mask signals.
///
/// `repr(transparent)` gives this type the same layout as `libc::sigset_t`, which lets the
/// methods below cast a pointer to `SignalSet` into the pointer expected by the libc functions.
#[repr(transparent)]
pub(crate) struct SignalSet {
    // The underlying C signal set.
    raw: libc::sigset_t,
}

impl SignalSet {
    /// Create a set that contains no signals.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cerr` when `sigemptyset` fails.
    pub(crate) fn empty() -> io::Result<Self> {
        let mut slot = MaybeUninit::<Self>::zeroed();
        // Valid cast: `SignalSet` is a `repr(transparent)` newtype over `libc::sigset_t`.
        let raw_ptr: *mut libc::sigset_t = slot.as_mut_ptr().cast();

        // SAFETY: `raw_ptr` points to writable storage with the layout of a `sigset_t`.
        let status = unsafe { libc::sigemptyset(raw_ptr) };
        cerr(status)?;

        // SAFETY: `sigemptyset` succeeded, so it has initialized `slot`.
        Ok(unsafe { slot.assume_init() })
    }

    /// Create a set that contains every signal.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cerr` when `sigfillset` fails.
    pub(crate) fn full() -> io::Result<Self> {
        let mut slot = MaybeUninit::<Self>::zeroed();
        // Valid cast: `SignalSet` is a `repr(transparent)` newtype over `libc::sigset_t`.
        let raw_ptr: *mut libc::sigset_t = slot.as_mut_ptr().cast();

        // SAFETY: `raw_ptr` points to writable storage with the layout of a `sigset_t`.
        let status = unsafe { libc::sigfillset(raw_ptr) };
        cerr(status)?;

        // SAFETY: `sigfillset` succeeded, so it has initialized `slot`.
        Ok(unsafe { slot.assume_init() })
    }

    /// Add the signal `sig` to this set.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cerr` when `sigaddset` fails.
    pub(crate) fn add(&mut self, sig: SignalNumber) -> io::Result<()> {
        // SAFETY: `&mut self.raw` is a valid, exclusive pointer to an initialized `sigset_t`.
        let status = unsafe { libc::sigaddset(&mut self.raw, sig) };

        cerr(status).map(|_| ())
    }

    /// Apply this set to the signal mask using the `sigprocmask` operation `how` and return the
    /// mask that was in effect before the call.
    ///
    /// Within this type, `how` is one of `SIG_BLOCK`, `SIG_UNBLOCK` or `SIG_SETMASK`.
    fn sigprocmask(&self, how: c_int) -> io::Result<Self> {
        let mut previous = MaybeUninit::<Self>::zeroed();
        // Valid cast: `SignalSet` is a `repr(transparent)` newtype over `libc::sigset_t`.
        let previous_ptr: *mut libc::sigset_t = previous.as_mut_ptr().cast();

        // SAFETY: `&self.raw` is a live reference to an initialized set and `previous_ptr`
        // points to writable storage with the layout of a `sigset_t`.
        let status = unsafe { libc::sigprocmask(how, &self.raw, previous_ptr) };
        cerr(status)?;

        // SAFETY: `sigprocmask` succeeded, so it has written the old mask into `previous`.
        Ok(unsafe { previous.assume_init() })
    }

    /// Block every signal in this set, returning the set of signals that was blocked before.
    ///
    /// On success the blocked signals are the union of the previously blocked signals and this
    /// set.
    pub(crate) fn block(&self) -> io::Result<Self> {
        self.sigprocmask(libc::SIG_BLOCK)
    }

    /// Unblock every signal in this set, returning the set of signals that was blocked before.
    ///
    /// On success the blocked signals are the previously blocked signals minus this set.
    pub(crate) fn unblock(&self) -> io::Result<Self> {
        self.sigprocmask(libc::SIG_UNBLOCK)
    }

    /// Replace the set of blocked signals with this set, returning the set that was blocked
    /// before.
    ///
    /// On success the blocked signals are exactly the signals in this set.
    pub(crate) fn set_mask(&self) -> io::Result<Self> {
        self.sigprocmask(libc::SIG_SETMASK)
    }
}