brush_core/
traps.rs

//! Facilities for configuring trap handlers.
2
3use std::str::FromStr;
4use std::{collections::HashMap, fmt::Display};
5
6use itertools::Itertools as _;
7
8use crate::error;
9
10/// Type of signal that can be trapped in the shell.
11#[derive(Clone, Copy, Eq, Hash, PartialEq)]
12pub enum TrapSignal {
13    /// A system signal.
14    #[cfg(unix)]
15    Signal(nix::sys::signal::Signal),
16    /// The `DEBUG` trap.
17    Debug,
18    /// The `ERR` trap.
19    Err,
20    /// The `EXIT` trap.
21    Exit,
22    /// The `RETURN` trp.
23    Return,
24}
25
26impl Display for TrapSignal {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.write_str(self.as_str())
29    }
30}
31
32impl TrapSignal {
33    /// Returns all possible values of [`TrapSignal`].
34    pub fn iterator() -> impl Iterator<Item = Self> {
35        const SIGNALS: &[TrapSignal] = &[TrapSignal::Debug, TrapSignal::Err, TrapSignal::Exit];
36        let iter = SIGNALS.iter().copied();
37
38        #[cfg(unix)]
39        let iter = itertools::chain!(
40            iter,
41            nix::sys::signal::Signal::iterator().map(TrapSignal::Signal)
42        );
43
44        iter
45    }
46
47    /// Converts [`TrapSignal`] into its corresponding signal name as a [`&'static str`](str)
48    pub const fn as_str(self) -> &'static str {
49        match self {
50            #[cfg(unix)]
51            Self::Signal(s) => s.as_str(),
52            Self::Debug => "DEBUG",
53            Self::Err => "ERR",
54            Self::Exit => "EXIT",
55            Self::Return => "RETURN",
56        }
57    }
58}
59
60/// Formats [`Iterator<Item = TrapSignal>`](TrapSignal)  to the provided writer.
61///
62/// # Arguments
63///
64/// * `f` - Any type that implements [`std::io::Write`].
65/// * `it` - An iterator over the signals that will be formatted into the `f`.
66pub fn format_signals(
67    mut f: impl std::io::Write,
68    it: impl Iterator<Item = TrapSignal>,
69) -> Result<(), error::Error> {
70    let it = it
71        .filter_map(|s| i32::try_from(s).ok().map(|n| (s, n)))
72        .sorted_by(|a, b| Ord::cmp(&a.1, &b.1))
73        .format_with("\n", |s, f| f(&format_args!("{}) {}", s.1, s.0)));
74    write!(f, "{it}")?;
75    Ok(())
76}
77
78// implement s.parse::<TrapSignal>()
79impl FromStr for TrapSignal {
80    type Err = error::Error;
81    fn from_str(s: &str) -> Result<Self, <Self as FromStr>::Err> {
82        if let Ok(n) = s.parse::<i32>() {
83            Self::try_from(n)
84        } else {
85            Self::try_from(s)
86        }
87    }
88}
89
90// from a signal number
91impl TryFrom<i32> for TrapSignal {
92    type Error = error::Error;
93    fn try_from(value: i32) -> Result<Self, Self::Error> {
94        // NOTE: DEBUG and ERR are real-time signals, defined based on NSIG or SIGRTMAX (is not
95        // available on bsd-like systems),
96        // and don't have persistent numbers across platforms, so we skip them here.
97        Ok(match value {
98            0 => Self::Exit,
99            #[cfg(unix)]
100            value => Self::Signal(
101                nix::sys::signal::Signal::try_from(value)
102                    .map_err(|_| error::Error::InvalidSignal(value.to_string()))?,
103            ),
104            #[cfg(not(unix))]
105            _ => return Err(error::Error::InvalidSignal(value.to_string())),
106        })
107    }
108}
109
110// from a signal name
111impl TryFrom<&str> for TrapSignal {
112    type Error = error::Error;
113    fn try_from(value: &str) -> Result<Self, Self::Error> {
114        #[allow(unused_mut)] // on not unix platforms
115        let mut s = value.to_ascii_uppercase();
116
117        Ok(match s.as_str() {
118            "DEBUG" => Self::Debug,
119            "ERR" => Self::Err,
120            "EXIT" => Self::Exit,
121            "RETURN" => Self::Return,
122
123            #[cfg(unix)]
124            _ => {
125                // Bash compatibility:
126                // support for signal names without the `SIG` prefix, for example `HUP` -> `SIGHUP`
127                if !s.starts_with("SIG") {
128                    s.insert_str(0, "SIG");
129                }
130                nix::sys::signal::Signal::from_str(s.as_str())
131                    .map(TrapSignal::Signal)
132                    .map_err(|_| error::Error::InvalidSignal(value.into()))?
133            }
134            #[cfg(not(unix))]
135            _ => return Err(error::Error::InvalidSignal(value.into())),
136        })
137    }
138}
139
140/// Error type used when failing to convert a `TrapSignal` to a number.
141#[derive(Debug, Clone, Copy)]
142pub struct TrapSignalNumberError;
143
144impl TryFrom<TrapSignal> for i32 {
145    type Error = TrapSignalNumberError;
146    fn try_from(value: TrapSignal) -> Result<Self, Self::Error> {
147        Ok(match value {
148            #[cfg(unix)]
149            TrapSignal::Signal(s) => s as Self,
150            TrapSignal::Exit => 0,
151            _ => return Err(TrapSignalNumberError),
152        })
153    }
154}
155
156/// Configuration for trap handlers in the shell.
157#[derive(Clone, Default)]
158pub struct TrapHandlerConfig {
159    /// Registered handlers for traps; maps signal type to command.
160    pub(crate) handlers: HashMap<TrapSignal, String>,
161    /// Current depth of the handler stack.
162    pub(crate) handler_depth: i32,
163}
164
165impl TrapHandlerConfig {
166    /// Registers a handler for a trap signal.
167    ///
168    /// # Arguments
169    ///
170    /// * `signal_type` - The type of signal to register a handler for.
171    /// * `command` - The command to execute when the signal is trapped.
172    pub fn register_handler(&mut self, signal_type: TrapSignal, command: String) {
173        let _ = self.handlers.insert(signal_type, command);
174    }
175
176    /// Removes handlers for a trap signal.
177    ///
178    /// # Arguments
179    ///
180    /// * `signal_type` - The type of signal to remove handlers for.
181    pub fn remove_handlers(&mut self, signal_type: TrapSignal) {
182        self.handlers.remove(&signal_type);
183    }
184}