mount_watcher/
watch.rs

//! Watches the mount table (`/proc/mounts`) and reports mount/unmount events to a user callback.
2
3use std::{
4    collections::HashSet,
5    fs::File,
6    io::ErrorKind,
7    os::fd::AsRawFd,
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc,
11    },
12    thread::JoinHandle,
13    time::Duration,
14};
15
16use mio::{unix::SourceFd, Events, Interest, Poll, Token};
17use thiserror::Error;
18use timerfd::TimerFd;
19
20use crate::mount::ReadError;
21
22use super::mount::{read_proc_mounts, LinuxMount};
23
/// `MountWatcher` lets you react to changes in the mounted filesystems.
///
/// The callback passed to [`MountWatcher::new`] runs on a dedicated background thread.
///
/// # Stopping
///
/// When the `MountWatcher` is dropped, the background thread that drives the watcher is asked
/// to stop. You can also request this explicitly with [`stop`](Self::stop).
/// Neither dropping nor `stop` waits for the thread: a callback invocation that is already
/// running may still complete. To wait for the thread to exit, use [`join`](Self::join).
///
/// Furthermore, you can stop the watcher from the event handler itself, by returning
/// [`WatchControl::Stop`].
///
/// # Example (stop in handler)
///
/// ```no_run
/// use mount_watcher::{MountWatcher, WatchControl};
///
/// let watch = MountWatcher::new(|event| {
///     let added_mounts = event.mounted;
///     let removed_mounts = event.unmounted;
///     let stop_condition = todo!();
///     if stop_condition {
///         // I have found what I wanted, stop here.
///         WatchControl::Stop
///     } else {
///         // Continue to watch, I still want events.
///         WatchControl::Continue
///     }
/// }).unwrap();
/// // Wait for the watcher to be stopped by the handler
/// watch.join().unwrap();
/// ```
#[derive(Debug)]
pub struct MountWatcher {
    /// Handle of the polling thread; `None` once `join` has taken it.
    thread_handle: Option<JoinHandle<()>>,
    /// Shared with the polling thread, which exits once it observes `true`.
    stop_flag: Arc<AtomicBool>,
}
57
58/// Error in `MountWatcher` setup.
59#[derive(Debug, Error)]
60#[error("MountWatcher setup error")]
61pub struct SetupError(#[source] ErrorImpl);
62
63/// Private error type: I don't want to expose it for the moment.
64#[derive(Debug, Error)]
65enum ErrorImpl {
66    #[error("read error")]
67    MountRead(#[from] ReadError),
68    #[error("failed to initialize epoll")]
69    PollInit(#[source] std::io::Error),
70    #[error("poll.poll() returned an error")]
71    PollPoll(#[source] std::io::Error),
72    #[error("failed to register a timer to epoll")]
73    PollTimer(#[source] std::io::Error),
74    #[error("could not set up a timer with delay {0:?} for event coalescing")]
75    Timerfd(Duration, #[source] std::io::Error),
76}
77
78impl MountWatcher {
79    /// Watches the list of mounted filesystems and executes the `callback` when it changes.
80    pub fn new(
81        callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
82    ) -> Result<Self, SetupError> {
83        watch_mounts(callback).map_err(SetupError)
84    }
85
86    /// Requests the background thread to terminate.
87    ///
88    /// To wait for the termination, use [`join`].
89    pub fn stop(&self) {
90        self.stop_flag.store(true, Ordering::Relaxed);
91    }
92
93    /// Waits for the background thread to terminate.
94    ///
95    /// This blocks the current thread.
96    ///
97    /// # Errors
98    /// If the background thread has panicked, an error is returned with the panic payload.
99    pub fn join(mut self) -> std::thread::Result<()> {
100        self.thread_handle.take().unwrap().join()
101    }
102}
103
104impl Drop for MountWatcher {
105    fn drop(&mut self) {
106        if self.thread_handle.is_some() {
107            self.stop();
108        }
109    }
110}
111
112/// Event generated when the mounted filesystems change.
113pub struct MountEvent {
114    /// The new filesystems that have been mounted.
115    pub mounted: Vec<LinuxMount>,
116
117    /// The old filesystems that have been unmounted.
118    pub unmounted: Vec<LinuxMount>,
119
120    /// Indicates whether this is a coalesced event.
121    ///
122    /// See [`WatchControl::Coalesce`].
123    pub coalesced: bool,
124
125    /// Indicates whether this is the first event, which contains
126    /// the list of all the mounts.
127    pub initial: bool,
128}
129
/// Value returned by the event handler to control the [`MountWatcher`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Continue watching.
    Continue,
    /// Stop watching: the background thread exits and the callback is not called again.
    Stop,
    /// After the given delay, call the callback again.
    ///
    /// The current event is not acknowledged. Notifications received during the delay do
    /// not invoke the callback. When the delay expires, the mount table is compared again
    /// with the state preceding the current event, so the current mounts/unmounts are
    /// reported together with those that occurred during the delay. This happens in an
    /// event with [`MountEvent::coalesced`] set, unless nothing changed overall.
    Coalesce { delay: Duration },
}
142
143const MOUNT_TOKEN: Token = Token(0);
144const TIMER_TOKEN: Token = Token(1);
/// Longest time the polling thread blocks before looking at the stop flag again.
const POLL_TIMEOUT: Duration = Duration::from_secs(5);
/// Path of the mount table that is watched.
const PROC_MOUNTS_PATH: &str = "/proc/mounts";
147
148struct State<F: FnMut(MountEvent) -> WatchControl> {
149    known_mounts: HashSet<LinuxMount>,
150    callback: F,
151    coalesce_timer: Option<TimerFd>,
152    coalescing: bool,
153}
154
155impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
156    fn new(callback: F) -> Self {
157        Self {
158            known_mounts: HashSet::with_capacity(8),
159            callback,
160            coalesce_timer: None,
161            coalescing: false,
162        }
163    }
164
165    fn check_diff(
166        &mut self,
167        file: &mut File,
168        coalesced: bool,
169        initial: bool,
170    ) -> Result<WatchControl, ReadError> {
171        debug_assert!(
172            !(coalesced && !self.coalescing),
173            "inconsistent state: coalescing flag should be set before setting the trigger up"
174        );
175        if self.coalescing {
176            if coalesced {
177                // The timer has been triggered, clear the flag.
178                self.coalescing = false;
179            } else {
180                // We are coalescing the events, wait for the timer.
181                return Ok(WatchControl::Continue);
182            }
183        }
184
185        let mounts = read_proc_mounts(file)?;
186        let mounts = HashSet::from_iter(mounts);
187        let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
188        let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
189        log::trace!("known_mounts: {:?}", self.known_mounts);
190        log::trace!("curr. mounts: {:?}", mounts);
191
192        if mounted.is_empty() && unmounted.is_empty() {
193            // Weird: we got a notification but nothing has changed?
194            // Perhaps something was undone between the moment we got the notification and
195            // the moment we read the /proc/mounts virtual file?
196            log::warn!("nothing changed");
197            return Ok(WatchControl::Continue);
198        }
199
200        // call the callback with the changes
201        let event = MountEvent {
202            mounted: mounted.into_iter().cloned().collect(),
203            unmounted: unmounted.into_iter().cloned().collect(),
204            coalesced,
205            initial,
206        };
207        let res = (self.callback)(event);
208        if !matches!(res, WatchControl::Coalesce { .. }) {
209            // When coalescing, don't save the new mounts, we'll compute
210            // the difference again and send the future result instead.
211            // On the contrary, when NOT coalescing, save the new mounts.
212            self.known_mounts = mounts;
213        }
214        // propagate the choice of the callback
215        Ok(res)
216    }
217
218    fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
219        log::trace!("start coalescing for {delay:?}");
220        let mut register = false;
221        if self.coalesce_timer.is_none() {
222            // create the timer, don't register it yet because it is not configured
223            self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
224            register = true;
225            log::trace!("timerfd created");
226        }
227
228        // configure the timer
229        let timer = self.coalesce_timer.as_mut().unwrap();
230        timer.set_state(
231            timerfd::TimerState::Oneshot(delay),
232            timerfd::SetTimeFlags::Default,
233        );
234
235        // register the timer to the epoll instance
236        if register {
237            let fd = timer.as_raw_fd();
238            let mut source = SourceFd(&fd);
239            poll.registry()
240                .register(&mut source, TIMER_TOKEN, Interest::READABLE)
241                .map_err(ErrorImpl::PollTimer)?;
242            log::trace!("timerfd registered");
243        }
244        // set the coalescing flag
245        self.coalescing = true;
246        Ok(())
247    }
248}
249
250/// Starts a background thread that uses [`mio::poll`] (backed by `epoll`) to detect changes to the mounted filesystem.
251fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
252    callback: F,
253) -> Result<MountWatcher, ErrorImpl> {
254    // Open the file that contains info about the mounted filesystems.
255    let mut file =
256        File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
257    let fd = file.as_raw_fd();
258    let mut fd = SourceFd(&fd);
259
260    // Prepare epoll.
261    // According to `man proc_mounts`, a filesystem mount or unmount causes
262    // `poll` and `epoll_wait` to mark the file as having a PRIORITY event.
263    let mut poll = Poll::new().map_err(|e| ErrorImpl::PollInit(e))?;
264    poll.registry()
265        .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
266        .map_err(|e| ErrorImpl::PollInit(e))?;
267
268    // Keep a boolean to stop the thread from the outside.
269    let stop_flag = Arc::new(AtomicBool::new(false));
270    let stop_flag_thread = stop_flag.clone();
271
272    // Declare the polling loop separately to handle errors in a nicer way.
273    let poll_loop = move || -> Result<(), ErrorImpl> {
274        let mut events = Events::with_capacity(8); // we don't expect many events
275        let mut state = State::new(callback);
276
277        // While we were setting up epoll, some filesystems may have been mounted.
278        // Check that here to avoid any miss.
279        match state.check_diff(&mut file, false, true)? {
280            WatchControl::Continue => (),
281            WatchControl::Stop => return Ok(()),
282            WatchControl::Coalesce { delay } => {
283                state.start_coalescing(delay, &poll)?;
284            }
285        }
286
287        loop {
288            let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
289            if let Err(e) = poll_res {
290                if e.kind() == ErrorKind::Interrupted {
291                    continue; // retry
292                } else {
293                    return Err(ErrorImpl::PollPoll(e)); // propagate error
294                }
295            }
296
297            // Call next() because we are not interested in each individual event.
298            // If the timeout elapses, the event list is empty.
299            if let Some(event) = events.iter().next() {
300                log::debug!("event on /proc/mounts: {event:?}");
301
302                // parse mount file and react to changes
303                let coalesced = dbg!(event.token() == TIMER_TOKEN);
304                match state.check_diff(&mut file, coalesced, false)? {
305                    WatchControl::Continue => (),
306                    WatchControl::Stop => break,
307                    WatchControl::Coalesce { delay } => {
308                        state.start_coalescing(delay, &poll)?;
309                    }
310                }
311            }
312            if stop_flag_thread.load(Ordering::Relaxed) {
313                break;
314            }
315        }
316        Ok(())
317    };
318
319    // Spawn a thread.
320    let thread_handle = std::thread::spawn(move || {
321        if let Err(e) = poll_loop() {
322            log::error!("error in poll loop: {e:?}");
323        }
324    });
325
326    // Return a structure that will stop the polling when dropped.
327    Ok(MountWatcher {
328        thread_handle: Some(thread_handle),
329        stop_flag,
330    })
331}