1use std::{
4 collections::HashSet, fs::File, io::ErrorKind, os::fd::AsRawFd, sync::Arc, thread::JoinHandle,
5 time::Duration,
6};
7
8use mio::{unix::SourceFd, Events, Interest, Poll, Token, Waker};
9use thiserror::Error;
10use timerfd::TimerFd;
11
12use crate::mount::ReadError;
13
14use super::mount::{read_proc_mounts, LinuxMount};
15
16pub struct MountWatcher {
46 thread_handle: Option<JoinHandle<()>>,
47 stop_waker: Arc<Waker>,
48}
49
50#[derive(Debug, Error)]
52#[error("MountWatcher setup error")]
53pub struct SetupError(#[source] ErrorImpl);
54
55#[derive(Debug, Error)]
57#[error("MountWatcher stop error")]
58pub struct StopError(#[source] ErrorImpl);
59
60#[derive(Debug, Error)]
62enum ErrorImpl {
63 #[error("read error")]
64 MountRead(#[from] ReadError),
65 #[error("failed to initialize epoll")]
66 PollInit(#[source] std::io::Error),
67 #[error("poll.poll() returned an error")]
68 PollPoll(#[source] std::io::Error),
69 #[error("failed to register a timer to epoll")]
70 PollTimer(#[source] std::io::Error),
71 #[error("could not set up a timer with delay {0:?} for event coalescing")]
72 Timerfd(Duration, #[source] std::io::Error),
73 #[error("failed to stop epoll from another thread")]
74 Stop(#[source] std::io::Error),
75}
76
77impl MountWatcher {
78 pub fn new(
80 callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
81 ) -> Result<Self, SetupError> {
82 watch_mounts(callback).map_err(SetupError)
83 }
84
85 pub fn stop(&self) -> Result<(), StopError> {
89 self.stop_waker
90 .wake()
91 .map_err(|e| StopError(ErrorImpl::Stop(e)))
92 }
93
94 pub fn join(mut self) -> std::thread::Result<()> {
101 self.thread_handle.take().unwrap().join()
102 }
103}
104
105impl Drop for MountWatcher {
106 fn drop(&mut self) {
107 if self.thread_handle.is_some() {
108 let _ = self.stop();
109 }
110 }
111}
112
113pub struct MountEvent {
115 pub mounted: Vec<LinuxMount>,
117
118 pub unmounted: Vec<LinuxMount>,
120
121 pub coalesced: bool,
125
126 pub initial: bool,
129}
130
/// What the watcher should do after the callback has handled a [`MountEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Keep watching. The reported state becomes the new baseline.
    Continue,
    /// Stop the watcher thread.
    Stop,
    /// Do not accept this event. Ignore change notifications for `delay`,
    /// then report again with every change accumulated since the last
    /// accepted state (the new event has `coalesced` set).
    Coalesce { delay: Duration },
}
143
144const MOUNT_TOKEN: Token = Token(0);
145const TIMER_TOKEN: Token = Token(1);
146const STOP_TOKEN: Token = Token(2);
147const POLL_TIMEOUT: Duration = Duration::from_secs(5);
148const PROC_MOUNTS_PATH: &str = "/proc/mounts";
149
150struct State<F: FnMut(MountEvent) -> WatchControl> {
151 known_mounts: HashSet<LinuxMount>,
152 callback: F,
153 coalesce_timer: Option<TimerFd>,
154 coalescing: bool,
155}
156
157impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
158 fn new(callback: F) -> Self {
159 Self {
160 known_mounts: HashSet::with_capacity(8),
161 callback,
162 coalesce_timer: None,
163 coalescing: false,
164 }
165 }
166
167 fn check_diff(
168 &mut self,
169 file: &mut File,
170 coalesced: bool,
171 initial: bool,
172 ) -> Result<WatchControl, ReadError> {
173 debug_assert!(
174 !(coalesced && !self.coalescing),
175 "inconsistent state: coalescing flag should be set before setting the trigger up"
176 );
177 if self.coalescing {
178 if coalesced {
179 self.coalescing = false;
181 } else {
182 return Ok(WatchControl::Continue);
184 }
185 }
186
187 let mounts = read_proc_mounts(file)?;
188 let mounts = HashSet::from_iter(mounts);
189 let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
190 let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
191 log::trace!("known_mounts: {:?}", self.known_mounts);
192 log::trace!("curr. mounts: {:?}", mounts);
193
194 if mounted.is_empty() && unmounted.is_empty() {
195 log::warn!("nothing changed");
199 return Ok(WatchControl::Continue);
200 }
201
202 let event = MountEvent {
204 mounted: mounted.into_iter().cloned().collect(),
205 unmounted: unmounted.into_iter().cloned().collect(),
206 coalesced,
207 initial,
208 };
209 let res = (self.callback)(event);
210 if !matches!(res, WatchControl::Coalesce { .. }) {
211 self.known_mounts = mounts;
215 }
216 Ok(res)
218 }
219
220 fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
221 log::trace!("start coalescing for {delay:?}");
222 let mut register = false;
223 if self.coalesce_timer.is_none() {
224 self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
226 register = true;
227 log::trace!("timerfd created");
228 }
229
230 let timer = self.coalesce_timer.as_mut().unwrap();
232 timer.set_state(
233 timerfd::TimerState::Oneshot(delay),
234 timerfd::SetTimeFlags::Default,
235 );
236
237 if register {
239 let fd = timer.as_raw_fd();
240 let mut source = SourceFd(&fd);
241 poll.registry()
242 .register(&mut source, TIMER_TOKEN, Interest::READABLE)
243 .map_err(ErrorImpl::PollTimer)?;
244 log::trace!("timerfd registered");
245 }
246 self.coalescing = true;
248 Ok(())
249 }
250}
251
252fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
254 callback: F,
255) -> Result<MountWatcher, ErrorImpl> {
256 let mut file =
258 File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
259 let fd = file.as_raw_fd();
260 let mut fd = SourceFd(&fd);
261
262 let mut poll = Poll::new().map_err(ErrorImpl::PollInit)?;
264
265 let stop_waker = Waker::new(poll.registry(), STOP_TOKEN).map_err(ErrorImpl::PollInit)?;
267
268 poll.registry()
271 .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
272 .map_err(ErrorImpl::PollInit)?;
273
274 let poll_loop = move || -> Result<(), ErrorImpl> {
276 let mut events = Events::with_capacity(8); let mut state = State::new(callback);
278
279 match state.check_diff(&mut file, false, true)? {
282 WatchControl::Continue => (),
283 WatchControl::Stop => return Ok(()),
284 WatchControl::Coalesce { delay } => {
285 state.start_coalescing(delay, &poll)?;
286 }
287 }
288
289 loop {
290 let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
291 if let Err(e) = poll_res {
292 if e.kind() == ErrorKind::Interrupted {
293 continue; } else {
295 return Err(ErrorImpl::PollPoll(e)); }
297 }
298
299 if let Some(event) = events.iter().next() {
302 log::debug!("event on /proc/mounts: {event:?}");
303
304 if event.token() == STOP_TOKEN {
306 break; }
308
309 let coalesced = event.token() == TIMER_TOKEN;
311 match state.check_diff(&mut file, coalesced, false)? {
312 WatchControl::Continue => (),
313 WatchControl::Stop => break,
314 WatchControl::Coalesce { delay } => {
315 state.start_coalescing(delay, &poll)?;
316 }
317 }
318 }
319 }
320 Ok(())
321 };
322
323 let thread_handle = std::thread::spawn(move || {
325 if let Err(e) = poll_loop() {
326 log::error!("error in polling loop: {e:?}");
327 }
328 });
329
330 Ok(MountWatcher {
332 thread_handle: Some(thread_handle),
333 stop_waker: Arc::new(stop_waker),
334 })
335}