Skip to main content

brush_core/
traps.rs

1//! Facilities for configuring trap handlers.
2
3use std::str::FromStr;
4use std::{collections::HashMap, fmt::Display};
5
6use itertools::Itertools as _;
7
8use crate::{error, sys};
9
10/// Type of signal that can be trapped in the shell.
11#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
12pub enum TrapSignal {
13    /// A system signal.
14    Signal(sys::signal::Signal),
15    /// The `DEBUG` trap.
16    Debug,
17    /// The `ERR` trap.
18    Err,
19    /// The `EXIT` trap.
20    Exit,
21    /// The `RETURN` trp.
22    Return,
23}
24
#[cfg(feature = "serde")]
impl serde::Serialize for TrapSignal {
    /// Serializes the signal as its canonical name string (e.g. `"EXIT"`).
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let name = self.as_str();
        serializer.serialize_str(name)
    }
}
34
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for TrapSignal {
    /// Deserializes a signal from a name string, accepting the same spellings
    /// as the `TryFrom<&str>` conversion; parse failures become custom serde errors.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let name: String = serde::Deserialize::deserialize(deserializer)?;
        match Self::try_from(name.as_str()) {
            Ok(signal) => Ok(signal),
            Err(err) => Err(serde::de::Error::custom(err)),
        }
    }
}
45
46impl Display for TrapSignal {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.write_str(self.as_str())
49    }
50}
51
52impl TrapSignal {
53    /// Returns all possible values of [`TrapSignal`].
54    pub fn iterator() -> impl Iterator<Item = Self> {
55        const SIGNALS: &[TrapSignal] = &[TrapSignal::Debug, TrapSignal::Err, TrapSignal::Exit];
56
57        let iter = itertools::chain!(
58            SIGNALS.iter().copied(),
59            sys::signal::Signal::iterator().map(TrapSignal::Signal)
60        );
61
62        iter
63    }
64
65    /// Converts [`TrapSignal`] into its corresponding signal name as a [`&'static str`](str)
66    pub const fn as_str(self) -> &'static str {
67        match self {
68            Self::Signal(s) => s.as_str(),
69            Self::Debug => "DEBUG",
70            Self::Err => "ERR",
71            Self::Exit => "EXIT",
72            Self::Return => "RETURN",
73        }
74    }
75}
76
77/// Formats [`Iterator<Item = TrapSignal>`](TrapSignal)  to the provided writer.
78///
79/// # Arguments
80///
81/// * `f` - Any type that implements [`std::io::Write`].
82/// * `it` - An iterator over the signals that will be formatted into the `f`.
83pub fn format_signals(
84    mut f: impl std::io::Write,
85    it: impl Iterator<Item = TrapSignal>,
86) -> Result<(), error::Error> {
87    let it = it
88        .filter_map(|s| i32::try_from(s).ok().map(|n| (s, n)))
89        .sorted_by(|a, b| Ord::cmp(&a.1, &b.1))
90        .format_with("\n", |s, f| f(&format_args!("{}) {}", s.1, s.0)));
91    write!(f, "{it}")?;
92    Ok(())
93}
94
95// implement s.parse::<TrapSignal>()
96impl FromStr for TrapSignal {
97    type Err = error::Error;
98    fn from_str(s: &str) -> Result<Self, <Self as FromStr>::Err> {
99        if let Ok(n) = s.parse::<i32>() {
100            Self::try_from(n)
101        } else {
102            Self::try_from(s)
103        }
104    }
105}
106
107// from a signal number
108impl TryFrom<i32> for TrapSignal {
109    type Error = error::Error;
110    fn try_from(value: i32) -> Result<Self, Self::Error> {
111        // NOTE: DEBUG and ERR are real-time signals, defined based on NSIG or SIGRTMAX (is not
112        // available on bsd-like systems),
113        // and don't have persistent numbers across platforms, so we skip them here.
114        Ok(match value {
115            0 => Self::Exit,
116            value => Self::Signal(
117                sys::signal::Signal::try_from(value)
118                    .map_err(|_| error::ErrorKind::InvalidSignal(value.to_string()))?,
119            ),
120        })
121    }
122}
123
124// from a signal name
125impl TryFrom<&str> for TrapSignal {
126    type Error = error::Error;
127    fn try_from(value: &str) -> Result<Self, Self::Error> {
128        #[allow(unused_mut, reason = "only mutated on some platforms")]
129        let mut s = value.to_ascii_uppercase();
130
131        Ok(match s.as_str() {
132            "DEBUG" => Self::Debug,
133            "ERR" => Self::Err,
134            "EXIT" => Self::Exit,
135            "RETURN" => Self::Return,
136            _ => {
137                // Bash compatibility:
138                // support for signal names without the `SIG` prefix, for example `HUP` -> `SIGHUP`
139                if !s.starts_with("SIG") {
140                    s.insert_str(0, "SIG");
141                }
142                sys::signal::Signal::from_str(s.as_str())
143                    .map(TrapSignal::Signal)
144                    .map_err(|_| error::ErrorKind::InvalidSignal(value.into()))?
145            }
146        })
147    }
148}
149
/// Error type used when failing to convert a `TrapSignal` to a number.
///
/// Produced by the `TryFrom<TrapSignal> for i32` conversion for traps that
/// have no numeric form.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrapSignalNumberError;

impl std::fmt::Display for TrapSignalNumberError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("trap signal has no corresponding signal number")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for TrapSignalNumberError {}
153
154impl TryFrom<TrapSignal> for i32 {
155    type Error = TrapSignalNumberError;
156    fn try_from(value: TrapSignal) -> Result<Self, Self::Error> {
157        Ok(match value {
158            TrapSignal::Signal(s) => s as Self,
159            TrapSignal::Exit => 0,
160            _ => return Err(TrapSignalNumberError),
161        })
162    }
163}
164
165/// A handler for a trap signal.
166#[derive(Clone, Default)]
167#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
168pub struct TrapHandler {
169    /// The source text of the command to invoke.
170    pub command: String,
171    /// Source information for where the trap handler was defined.
172    pub source_info: crate::SourceInfo,
173}
174
175/// Configuration for trap handlers in the shell.
176#[derive(Clone, Default)]
177#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
178pub struct TrapHandlerConfig {
179    /// Registered handlers for traps; maps signal type to command.
180    handlers: HashMap<TrapSignal, TrapHandler>,
181}
182
183impl TrapHandlerConfig {
184    /// Iterates over the registered handlers for trap signals.
185    pub fn iter_handlers(&self) -> impl Iterator<Item = (TrapSignal, &TrapHandler)> {
186        self.handlers
187            .iter()
188            .map(|(signal, handler)| (*signal, handler))
189    }
190
191    /// Tries to find the handler associated with the given signal.
192    ///
193    /// # Arguments
194    ///
195    /// * `signal_type` - The type of signal to get the handler for.
196    pub fn get_handler(&self, signal_type: TrapSignal) -> Option<&TrapHandler> {
197        self.handlers.get(&signal_type)
198    }
199
200    /// Returns whether a handler is registered for the given signal.
201    pub fn handles(&self, signal_type: TrapSignal) -> bool {
202        self.handlers.contains_key(&signal_type)
203    }
204
205    /// Registers a handler for a trap signal.
206    ///
207    /// # Arguments
208    ///
209    /// * `signal_type` - The type of signal to register a handler for.
210    /// * `command` - The command to execute when the signal is trapped.
211    /// * `source_info` - The source info for where the trap handler was defined.
212    pub fn register_handler(
213        &mut self,
214        signal_type: TrapSignal,
215        command: String,
216        source_info: crate::SourceInfo,
217    ) {
218        let _ = self.handlers.insert(
219            signal_type,
220            TrapHandler {
221                command,
222                source_info,
223            },
224        );
225    }
226
227    /// Removes handlers for a trap signal.
228    ///
229    /// # Arguments
230    ///
231    /// * `signal_type` - The type of signal to remove handlers for.
232    pub fn remove_handlers(&mut self, signal_type: TrapSignal) {
233        self.handlers.remove(&signal_type);
234    }
235}