mount_watcher/
watch.rs

1//! Main module.
2
3use std::{
4    collections::HashSet,
5    fs::File,
6    io::ErrorKind,
7    os::fd::AsRawFd,
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc,
11    },
12    thread::JoinHandle,
13    time::Duration,
14};
15
16use mio::{unix::SourceFd, Events, Interest, Poll, Token};
17use thiserror::Error;
18use timerfd::TimerFd;
19
20use crate::mount::ReadError;
21
22use super::mount::{read_proc_mounts, LinuxMount};
23
/// `MountWatcher` allows to react to changes in the mounted filesystems.
///
/// The event handler runs on a background thread, started by [`MountWatcher::new`].
///
/// # Stopping
///
/// There are three ways to stop watching:
///
/// - Call [`stop`](Self::stop). It asks the background thread to stop and waits for it to
///   terminate. Once it has returned, the callback will never be called again.
/// - Return [`WatchControl::Stop`] from the event handler. Use [`join`](Self::join) to wait
///   for the background thread to terminate.
/// - Drop the `MountWatcher`. The background thread is asked to stop, but the drop does *not*
///   wait for it. The thread notices the request the next time it wakes up, so a callback
///   invocation may still be running (or, rarely, start) shortly after the drop. Use
///   [`stop`](Self::stop) if you need the guarantee that the callback is no longer called.
///
/// # Example (stop in handler)
///
/// ```no_run
/// use mount_watcher::{MountWatcher, WatchControl};
///
/// let watch = MountWatcher::new(|event| {
///     let added_mounts = event.mounted;
///     let removed_mounts = event.unmounted;
///     let stop_condition = todo!();
///     if stop_condition {
///         // I have found what I wanted, stop here.
///         WatchControl::Stop
///     } else {
///         // Continue to watch, I still want events.
///         WatchControl::Continue
///     }
/// }).unwrap();
/// // Wait for the watcher to be stopped by the handler
/// watch.join().unwrap();
/// ```
#[derive(Debug)]
#[must_use = "dropping the MountWatcher asks its background thread to stop watching"]
pub struct MountWatcher {
    /// Handle of the background thread; `None` once it has been joined.
    thread_handle: Option<JoinHandle<()>>,
    /// Set to `true` to ask the background thread to stop.
    stop_flag: Arc<AtomicBool>,
}
57
58/// Error in `MountWatcher` setup.
59#[derive(Debug, Error)]
60#[error("MountWatcher setup error")]
61pub struct SetupError(#[source] ErrorImpl);
62
63/// Private error type: I don't want to expose it for the moment.
64#[derive(Debug, Error)]
65enum ErrorImpl {
66    #[error("read error")]
67    MountRead(#[from] ReadError),
68    #[error("failed to initialize epoll")]
69    PollInit(#[source] std::io::Error),
70    #[error("poll.poll() returned an error")]
71    PollPoll(#[source] std::io::Error),
72    #[error("failed to register a timer to epoll")]
73    PollTimer(#[source] std::io::Error),
74    #[error("could not set up a timer with delay {0:?} for event coalescing")]
75    Timerfd(Duration, #[source] std::io::Error),
76}
77
78impl MountWatcher {
79    /// Watches the list of mounted filesystems and executes the `callback` when it changes.
80    pub fn new(
81        callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
82    ) -> Result<Self, SetupError> {
83        watch_mounts(callback).map_err(SetupError)
84    }
85
86    /// Stops the waiting thread and wait for it to terminate.
87    ///
88    /// # Errors
89    /// If the thread has panicked, an error is returned with the panic payload.
90    pub fn stop(mut self) -> std::thread::Result<()> {
91        self.stop_flag.store(true, Ordering::Relaxed);
92        self.thread_handle.take().unwrap().join()
93    }
94
95    pub fn join(mut self) -> std::thread::Result<()> {
96        self.thread_handle.take().unwrap().join()
97    }
98}
99
100impl Drop for MountWatcher {
101    fn drop(&mut self) {
102        if self.thread_handle.is_some() {
103            self.stop_flag.store(true, Ordering::Relaxed);
104        }
105    }
106}
107
108/// Event generated when the mounted filesystems change.
109pub struct MountEvent {
110    /// The new filesystems that have been mounted.
111    pub mounted: Vec<LinuxMount>,
112
113    /// The old filesystems that have been unmounted.
114    pub unmounted: Vec<LinuxMount>,
115
116    /// Indicates whether this is a coalesced event.
117    ///
118    /// See [`WatchControl::Coalesce`].
119    pub coalesced: bool,
120
121    /// Indicates whether this is the first event, which contains
122    /// the list of all the mounts.
123    pub initial: bool,
124}
125
/// Value returned by the event handler to control the [`MountWatcher`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Continue watching.
    Continue,
    /// Stop watching: the background thread terminates and the callback is not called again.
    Stop,
    /// Wait for `delay`, then call the callback again with a single coalesced event.
    ///
    /// The changes of the current event are not acknowledged: the coalesced event contains
    /// them again, in addition to the mounts/unmounts that occur during the delay, and has
    /// [`MountEvent::coalesced`] set to `true`. Mount notifications received during the
    /// delay do not trigger the callback. If, when the delay expires, the mount table is
    /// identical to the last acknowledged one, no event is delivered.
    Coalesce { delay: Duration },
}
138
139const MOUNT_TOKEN: Token = Token(0);
140const TIMER_TOKEN: Token = Token(1);
141const POLL_TIMEOUT: Duration = Duration::from_secs(5);
142const PROC_MOUNTS_PATH: &str = "/proc/mounts";
143
144struct State<F: FnMut(MountEvent) -> WatchControl> {
145    known_mounts: HashSet<LinuxMount>,
146    callback: F,
147    coalesce_timer: Option<TimerFd>,
148    coalescing: bool,
149}
150
151impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
152    fn new(callback: F) -> Self {
153        Self {
154            known_mounts: HashSet::with_capacity(8),
155            callback,
156            coalesce_timer: None,
157            coalescing: false,
158        }
159    }
160
161    fn check_diff(
162        &mut self,
163        file: &mut File,
164        coalesced: bool,
165        initial: bool,
166    ) -> Result<WatchControl, ReadError> {
167        debug_assert!(
168            !(coalesced && !self.coalescing),
169            "inconsistent state: coalescing flag should be set before setting the trigger up"
170        );
171        if self.coalescing {
172            if coalesced {
173                // The timer has been triggered, clear the flag.
174                self.coalescing = false;
175            } else {
176                // We are coalescing the events, wait for the timer.
177                return Ok(WatchControl::Continue);
178            }
179        }
180
181        let mounts = read_proc_mounts(file)?;
182        let mounts = HashSet::from_iter(mounts);
183        let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
184        let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
185        log::trace!("known_mounts: {:?}", self.known_mounts);
186        log::trace!("curr. mounts: {:?}", mounts);
187
188        if mounted.is_empty() && unmounted.is_empty() {
189            // Weird: we got a notification but nothing has changed?
190            // Perhaps something was undone between the moment we got the notification and
191            // the moment we read the /proc/mounts virtual file?
192            log::warn!("nothing changed");
193            return Ok(WatchControl::Continue);
194        }
195
196        // call the callback with the changes
197        let event = MountEvent {
198            mounted: mounted.into_iter().cloned().collect(),
199            unmounted: unmounted.into_iter().cloned().collect(),
200            coalesced,
201            initial,
202        };
203        let res = (self.callback)(event);
204        if !matches!(res, WatchControl::Coalesce { .. }) {
205            // When coalescing, don't save the new mounts, we'll compute
206            // the difference again and send the future result instead.
207            // On the contrary, when NOT coalescing, save the new mounts.
208            self.known_mounts = mounts;
209        }
210        // propagate the choice of the callback
211        Ok(res)
212    }
213
214    fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
215        log::trace!("start coalescing for {delay:?}");
216        let mut register = false;
217        if self.coalesce_timer.is_none() {
218            // create the timer, don't register it yet because it is not configured
219            self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
220            register = true;
221            log::trace!("timerfd created");
222        }
223
224        // configure the timer
225        let timer = self.coalesce_timer.as_mut().unwrap();
226        timer.set_state(
227            timerfd::TimerState::Oneshot(delay),
228            timerfd::SetTimeFlags::Default,
229        );
230
231        // register the timer to the epoll instance
232        if register {
233            let fd = timer.as_raw_fd();
234            let mut source = SourceFd(&fd);
235            poll.registry()
236                .register(&mut source, TIMER_TOKEN, Interest::READABLE)
237                .map_err(ErrorImpl::PollTimer)?;
238            log::trace!("timerfd registered");
239        }
240        // set the coalescing flag
241        self.coalescing = true;
242        Ok(())
243    }
244}
245
246/// Starts a background thread that uses [`mio::poll`] (backed by `epoll`) to detect changes to the mounted filesystem.
247fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
248    callback: F,
249) -> Result<MountWatcher, ErrorImpl> {
250    // Open the file that contains info about the mounted filesystems.
251    let mut file =
252        File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
253    let fd = file.as_raw_fd();
254    let mut fd = SourceFd(&fd);
255
256    // Prepare epoll.
257    // According to `man proc_mounts`, a filesystem mount or unmount causes
258    // `poll` and `epoll_wait` to mark the file as having a PRIORITY event.
259    let mut poll = Poll::new().map_err(|e| ErrorImpl::PollInit(e))?;
260    poll.registry()
261        .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
262        .map_err(|e| ErrorImpl::PollInit(e))?;
263
264    // Keep a boolean to stop the thread from the outside.
265    let stop_flag = Arc::new(AtomicBool::new(false));
266    let stop_flag_thread = stop_flag.clone();
267
268    // Declare the polling loop separately to handle errors in a nicer way.
269    let poll_loop = move || -> Result<(), ErrorImpl> {
270        let mut events = Events::with_capacity(8); // we don't expect many events
271        let mut state = State::new(callback);
272
273        // While we were setting up epoll, some filesystems may have been mounted.
274        // Check that here to avoid any miss.
275        match state.check_diff(&mut file, false, true)? {
276            WatchControl::Continue => (),
277            WatchControl::Stop => return Ok(()),
278            WatchControl::Coalesce { delay } => {
279                state.start_coalescing(delay, &poll)?;
280            }
281        }
282
283        loop {
284            let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
285            if let Err(e) = poll_res {
286                if e.kind() == ErrorKind::Interrupted {
287                    continue; // retry
288                } else {
289                    return Err(ErrorImpl::PollPoll(e)); // propagate error
290                }
291            }
292
293            // Call next() because we are not interested in each individual event.
294            // If the timeout elapses, the event list is empty.
295            if let Some(event) = events.iter().next() {
296                log::debug!("event on /proc/mounts: {event:?}");
297
298                // parse mount file and react to changes
299                let coalesced = dbg!(event.token() == TIMER_TOKEN);
300                match state.check_diff(&mut file, coalesced, false)? {
301                    WatchControl::Continue => (),
302                    WatchControl::Stop => break,
303                    WatchControl::Coalesce { delay } => {
304                        state.start_coalescing(delay, &poll)?;
305                    }
306                }
307            }
308            if stop_flag_thread.load(Ordering::Relaxed) {
309                break;
310            }
311        }
312        Ok(())
313    };
314
315    // Spawn a thread.
316    let thread_handle = std::thread::spawn(move || {
317        if let Err(e) = poll_loop() {
318            log::error!("error in poll loop: {e:?}");
319        }
320    });
321
322    // Return a structure that will stop the polling when dropped.
323    Ok(MountWatcher {
324        thread_handle: Some(thread_handle),
325        stop_flag,
326    })
327}