mount_watcher/
watch.rs

1//! Main module.
2
3use std::{
4    collections::HashSet, fs::File, io::ErrorKind, os::fd::AsRawFd, sync::Arc, thread::JoinHandle,
5    time::Duration,
6};
7
8use mio::{unix::SourceFd, Events, Interest, Poll, Token, Waker};
9use thiserror::Error;
10use timerfd::TimerFd;
11
12use crate::mount::ReadError;
13
14use super::mount::{read_proc_mounts, LinuxMount};
15
16/// `MountWatcher` allows to react to changes in the mounted filesystems.
17///
18/// # Stopping
19///
20/// When the `MountWatcher` is dropped, the background thread that drives the watcher is stopped, and the callback will never be called again.
21/// You can also call [`stop`](Self::stop).
22///
23/// Furthermore, you can stop the watcher from the event handler itself, by returning [`WatchControl::Stop`].
24///
25/// # Example (stop in handler)
26///
27/// ```no_run
28/// use mount_watcher::{MountWatcher, WatchControl};
29///
30/// let watch = MountWatcher::new(|event| {
31///     let added_mounts = event.mounted;
32///     let removed_mounts = event.unmounted;
33///     let stop_condition = todo!();
34///     if stop_condition {
35///         // I have found what I wanted, stop here.
36///         WatchControl::Stop
37///     } else {
38///         // Continue to watch, I still want events.
39///         WatchControl::Continue
40///     }
41/// }).unwrap();
42/// // Wait for the watcher to be stopped by the handler
43/// watch.join().unwrap();
44/// ```
45pub struct MountWatcher {
46    thread_handle: Option<JoinHandle<()>>,
47    stop_waker: Arc<Waker>,
48}
49
50/// Error in `MountWatcher` setup.
51#[derive(Debug, Error)]
52#[error("MountWatcher setup error")]
53pub struct SetupError(#[source] ErrorImpl);
54
55/// Error in [`MountWatcher::stop`].
56#[derive(Debug, Error)]
57#[error("MountWatcher stop error")]
58pub struct StopError(#[source] ErrorImpl);
59
60/// Private error type: I don't want to expose it for the moment.
61#[derive(Debug, Error)]
62enum ErrorImpl {
63    #[error("read error")]
64    MountRead(#[from] ReadError),
65    #[error("failed to initialize epoll")]
66    PollInit(#[source] std::io::Error),
67    #[error("poll.poll() returned an error")]
68    PollPoll(#[source] std::io::Error),
69    #[error("failed to register a timer to epoll")]
70    PollTimer(#[source] std::io::Error),
71    #[error("could not set up a timer with delay {0:?} for event coalescing")]
72    Timerfd(Duration, #[source] std::io::Error),
73    #[error("failed to stop epoll from another thread")]
74    Stop(#[source] std::io::Error),
75}
76
77impl MountWatcher {
78    /// Watches the list of mounted filesystems and executes the `callback` when it changes.
79    pub fn new(
80        callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
81    ) -> Result<Self, SetupError> {
82        watch_mounts(callback).map_err(SetupError)
83    }
84
85    /// Requests the background thread to terminate.
86    ///
87    /// To wait for the termination, use [`join`](Self::join).
88    pub fn stop(&self) -> Result<(), StopError> {
89        self.stop_waker
90            .wake()
91            .map_err(|e| StopError(ErrorImpl::Stop(e)))
92    }
93
94    /// Waits for the background thread to terminate.
95    ///
96    /// This blocks the current thread.
97    ///
98    /// # Errors
99    /// If the background thread has panicked, an error is returned with the panic payload.
100    pub fn join(mut self) -> std::thread::Result<()> {
101        self.thread_handle.take().unwrap().join()
102    }
103}
104
105impl Drop for MountWatcher {
106    fn drop(&mut self) {
107        if self.thread_handle.is_some() {
108            let _ = self.stop();
109        }
110    }
111}
112
113/// Event generated when the mounted filesystems change.
114pub struct MountEvent {
115    /// The new filesystems that have been mounted.
116    pub mounted: Vec<LinuxMount>,
117
118    /// The old filesystems that have been unmounted.
119    pub unmounted: Vec<LinuxMount>,
120
121    /// Indicates whether this is a coalesced event.
122    ///
123    /// See [`WatchControl::Coalesce`].
124    pub coalesced: bool,
125
126    /// Indicates whether this is the first event, which contains
127    /// the list of all the mounts.
128    pub initial: bool,
129}
130
/// Value returned by the event handler to control the [`MountWatcher`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Continue watching.
    Continue,
    /// Stop watching.
    Stop,
    /// Wait for the given delay, then call the callback again with a coalesced event.
    ///
    /// Changes that occur during the delay do not trigger the callback on their own. Instead,
    /// the coalesced event includes the current mounts/unmounts in addition to the new
    /// mounts/unmounts that occurred during the delay. If, by the end of the delay, the mount
    /// table has not changed in net terms, the callback is not called.
    Coalesce {
        /// How long to wait before delivering the coalesced event.
        delay: Duration,
    },
}
143
144const MOUNT_TOKEN: Token = Token(0);
145const TIMER_TOKEN: Token = Token(1);
146const STOP_TOKEN: Token = Token(2);
147const POLL_TIMEOUT: Duration = Duration::from_secs(5);
148const PROC_MOUNTS_PATH: &str = "/proc/mounts";
149
150struct State<F: FnMut(MountEvent) -> WatchControl> {
151    known_mounts: HashSet<LinuxMount>,
152    callback: F,
153    coalesce_timer: Option<TimerFd>,
154    coalescing: bool,
155}
156
157impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
158    fn new(callback: F) -> Self {
159        Self {
160            known_mounts: HashSet::with_capacity(8),
161            callback,
162            coalesce_timer: None,
163            coalescing: false,
164        }
165    }
166
167    fn check_diff(
168        &mut self,
169        file: &mut File,
170        coalesced: bool,
171        initial: bool,
172    ) -> Result<WatchControl, ReadError> {
173        debug_assert!(
174            !(coalesced && !self.coalescing),
175            "inconsistent state: coalescing flag should be set before setting the trigger up"
176        );
177        if self.coalescing {
178            if coalesced {
179                // The timer has been triggered, clear the flag.
180                self.coalescing = false;
181            } else {
182                // We are coalescing the events, wait for the timer.
183                return Ok(WatchControl::Continue);
184            }
185        }
186
187        let mounts = read_proc_mounts(file)?;
188        let mounts = HashSet::from_iter(mounts);
189        let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
190        let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
191        log::trace!("known_mounts: {:?}", self.known_mounts);
192        log::trace!("curr. mounts: {:?}", mounts);
193
194        if mounted.is_empty() && unmounted.is_empty() {
195            // Weird: we got a notification but nothing has changed?
196            // Perhaps something was undone between the moment we got the notification and
197            // the moment we read the /proc/mounts virtual file?
198            log::warn!("nothing changed");
199            return Ok(WatchControl::Continue);
200        }
201
202        // call the callback with the changes
203        let event = MountEvent {
204            mounted: mounted.into_iter().cloned().collect(),
205            unmounted: unmounted.into_iter().cloned().collect(),
206            coalesced,
207            initial,
208        };
209        let res = (self.callback)(event);
210        if !matches!(res, WatchControl::Coalesce { .. }) {
211            // When coalescing, don't save the new mounts, we'll compute
212            // the difference again and send the future result instead.
213            // On the contrary, when NOT coalescing, save the new mounts.
214            self.known_mounts = mounts;
215        }
216        // propagate the choice of the callback
217        Ok(res)
218    }
219
220    fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
221        log::trace!("start coalescing for {delay:?}");
222        let mut register = false;
223        if self.coalesce_timer.is_none() {
224            // create the timer, don't register it yet because it is not configured
225            self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
226            register = true;
227            log::trace!("timerfd created");
228        }
229
230        // configure the timer
231        let timer = self.coalesce_timer.as_mut().unwrap();
232        timer.set_state(
233            timerfd::TimerState::Oneshot(delay),
234            timerfd::SetTimeFlags::Default,
235        );
236
237        // register the timer to the epoll instance
238        if register {
239            let fd = timer.as_raw_fd();
240            let mut source = SourceFd(&fd);
241            poll.registry()
242                .register(&mut source, TIMER_TOKEN, Interest::READABLE)
243                .map_err(ErrorImpl::PollTimer)?;
244            log::trace!("timerfd registered");
245        }
246        // set the coalescing flag
247        self.coalescing = true;
248        Ok(())
249    }
250}
251
252/// Starts a background thread that uses [`mio::poll`] (backed by `epoll`) to detect changes to the mounted filesystem.
253fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
254    callback: F,
255) -> Result<MountWatcher, ErrorImpl> {
256    // Open the file that contains info about the mounted filesystems.
257    let mut file =
258        File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
259    let fd = file.as_raw_fd();
260    let mut fd = SourceFd(&fd);
261
262    // Prepare epoll.
263    let mut poll = Poll::new().map_err(ErrorImpl::PollInit)?;
264
265    // Create a mean to wake epoll from another thread.
266    let stop_waker = Waker::new(poll.registry(), STOP_TOKEN).map_err(ErrorImpl::PollInit)?;
267
268    // According to `man proc_mounts`, a filesystem mount or unmount causes
269    // `poll` and `epoll_wait` to mark the file as having a PRIORITY event.
270    poll.registry()
271        .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
272        .map_err(ErrorImpl::PollInit)?;
273
274    // Declare the polling loop separately to handle errors in a nicer way.
275    let poll_loop = move || -> Result<(), ErrorImpl> {
276        let mut events = Events::with_capacity(8); // we don't expect many events
277        let mut state = State::new(callback);
278
279        // While we were setting up epoll, some filesystems may have been mounted.
280        // Check that here to avoid any miss.
281        match state.check_diff(&mut file, false, true)? {
282            WatchControl::Continue => (),
283            WatchControl::Stop => return Ok(()),
284            WatchControl::Coalesce { delay } => {
285                state.start_coalescing(delay, &poll)?;
286            }
287        }
288
289        loop {
290            let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
291            if let Err(e) = poll_res {
292                if e.kind() == ErrorKind::Interrupted {
293                    continue; // retry
294                } else {
295                    return Err(ErrorImpl::PollPoll(e)); // propagate error
296                }
297            }
298
299            // Call next() because we are not interested in each individual event.
300            // If the timeout elapses, the event list is empty.
301            if let Some(event) = events.iter().next() {
302                log::debug!("event on /proc/mounts: {event:?}");
303
304                // the stop_waker has been triggered, which means that we must stop now
305                if event.token() == STOP_TOKEN {
306                    break; // stop
307                }
308
309                // parse mount file and react to changes
310                let coalesced = event.token() == TIMER_TOKEN;
311                match state.check_diff(&mut file, coalesced, false)? {
312                    WatchControl::Continue => (),
313                    WatchControl::Stop => break,
314                    WatchControl::Coalesce { delay } => {
315                        state.start_coalescing(delay, &poll)?;
316                    }
317                }
318            }
319        }
320        Ok(())
321    };
322
323    // Spawn a thread.
324    let thread_handle = std::thread::spawn(move || {
325        if let Err(e) = poll_loop() {
326            log::error!("error in polling loop: {e:?}");
327        }
328    });
329
330    // Return a structure that will stop the polling when dropped.
331    Ok(MountWatcher {
332        thread_handle: Some(thread_handle),
333        stop_waker: Arc::new(stop_waker),
334    })
335}