1use std::{
4 collections::HashSet,
5 fs::File,
6 io::ErrorKind,
7 os::fd::AsRawFd,
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc,
11 },
12 thread::JoinHandle,
13 time::Duration,
14};
15
16use mio::{unix::SourceFd, Events, Interest, Poll, Token};
17use thiserror::Error;
18use timerfd::TimerFd;
19
20use crate::mount::ReadError;
21
22use super::mount::{read_proc_mounts, LinuxMount};
23
/// Handle to the background thread that watches the system mount table.
///
/// Dropping a watcher that has not been joined asks its thread to stop without waiting for it;
/// use [`MountWatcher::join`] to wait for the thread to finish.
pub struct MountWatcher {
    /// Flag shared with the watcher thread; setting it asks the thread to leave its loop.
    stop_flag: Arc<AtomicBool>,
    /// Handle of the watcher thread, taken (left as `None`) by [`MountWatcher::join`].
    thread_handle: Option<JoinHandle<()>>,
}
57
58#[derive(Debug, Error)]
60#[error("MountWatcher setup error")]
61pub struct SetupError(#[source] ErrorImpl);
62
63#[derive(Debug, Error)]
65enum ErrorImpl {
66 #[error("read error")]
67 MountRead(#[from] ReadError),
68 #[error("failed to initialize epoll")]
69 PollInit(#[source] std::io::Error),
70 #[error("poll.poll() returned an error")]
71 PollPoll(#[source] std::io::Error),
72 #[error("failed to register a timer to epoll")]
73 PollTimer(#[source] std::io::Error),
74 #[error("could not set up a timer with delay {0:?} for event coalescing")]
75 Timerfd(Duration, #[source] std::io::Error),
76}
77
78impl MountWatcher {
79 pub fn new(
81 callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
82 ) -> Result<Self, SetupError> {
83 watch_mounts(callback).map_err(SetupError)
84 }
85
86 pub fn stop(&self) {
90 self.stop_flag.store(true, Ordering::Relaxed);
91 }
92
93 pub fn join(mut self) -> std::thread::Result<()> {
100 self.thread_handle.take().unwrap().join()
101 }
102}
103
104impl Drop for MountWatcher {
105 fn drop(&mut self) {
106 if self.thread_handle.is_some() {
107 self.stop();
108 }
109 }
110}
111
112pub struct MountEvent {
114 pub mounted: Vec<LinuxMount>,
116
117 pub unmounted: Vec<LinuxMount>,
119
120 pub coalesced: bool,
124
125 pub initial: bool,
128}
129
/// What the watcher should do after the callback has handled a [`MountEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WatchControl {
    /// Accept the reported changes and keep watching.
    Continue,
    /// Stop watching; the watcher thread leaves its loop.
    Stop,
    /// Do not accept the reported changes yet.
    ///
    /// A timer of `delay` is (re)started, and mount notifications are ignored until it
    /// fires. A single event with `coalesced: true` is then delivered. It describes every
    /// change relative to the last accepted mount table.
    Coalesce { delay: Duration },
}
142
143const MOUNT_TOKEN: Token = Token(0);
144const TIMER_TOKEN: Token = Token(1);
145const POLL_TIMEOUT: Duration = Duration::from_secs(5);
146const PROC_MOUNTS_PATH: &str = "/proc/mounts";
147
148struct State<F: FnMut(MountEvent) -> WatchControl> {
149 known_mounts: HashSet<LinuxMount>,
150 callback: F,
151 coalesce_timer: Option<TimerFd>,
152 coalescing: bool,
153}
154
155impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
156 fn new(callback: F) -> Self {
157 Self {
158 known_mounts: HashSet::with_capacity(8),
159 callback,
160 coalesce_timer: None,
161 coalescing: false,
162 }
163 }
164
165 fn check_diff(
166 &mut self,
167 file: &mut File,
168 coalesced: bool,
169 initial: bool,
170 ) -> Result<WatchControl, ReadError> {
171 debug_assert!(
172 !(coalesced && !self.coalescing),
173 "inconsistent state: coalescing flag should be set before setting the trigger up"
174 );
175 if self.coalescing {
176 if coalesced {
177 self.coalescing = false;
179 } else {
180 return Ok(WatchControl::Continue);
182 }
183 }
184
185 let mounts = read_proc_mounts(file)?;
186 let mounts = HashSet::from_iter(mounts);
187 let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
188 let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
189 log::trace!("known_mounts: {:?}", self.known_mounts);
190 log::trace!("curr. mounts: {:?}", mounts);
191
192 if mounted.is_empty() && unmounted.is_empty() {
193 log::warn!("nothing changed");
197 return Ok(WatchControl::Continue);
198 }
199
200 let event = MountEvent {
202 mounted: mounted.into_iter().cloned().collect(),
203 unmounted: unmounted.into_iter().cloned().collect(),
204 coalesced,
205 initial,
206 };
207 let res = (self.callback)(event);
208 if !matches!(res, WatchControl::Coalesce { .. }) {
209 self.known_mounts = mounts;
213 }
214 Ok(res)
216 }
217
218 fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
219 log::trace!("start coalescing for {delay:?}");
220 let mut register = false;
221 if self.coalesce_timer.is_none() {
222 self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
224 register = true;
225 log::trace!("timerfd created");
226 }
227
228 let timer = self.coalesce_timer.as_mut().unwrap();
230 timer.set_state(
231 timerfd::TimerState::Oneshot(delay),
232 timerfd::SetTimeFlags::Default,
233 );
234
235 if register {
237 let fd = timer.as_raw_fd();
238 let mut source = SourceFd(&fd);
239 poll.registry()
240 .register(&mut source, TIMER_TOKEN, Interest::READABLE)
241 .map_err(ErrorImpl::PollTimer)?;
242 log::trace!("timerfd registered");
243 }
244 self.coalescing = true;
246 Ok(())
247 }
248}
249
250fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
252 callback: F,
253) -> Result<MountWatcher, ErrorImpl> {
254 let mut file =
256 File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
257 let fd = file.as_raw_fd();
258 let mut fd = SourceFd(&fd);
259
260 let mut poll = Poll::new().map_err(|e| ErrorImpl::PollInit(e))?;
264 poll.registry()
265 .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
266 .map_err(|e| ErrorImpl::PollInit(e))?;
267
268 let stop_flag = Arc::new(AtomicBool::new(false));
270 let stop_flag_thread = stop_flag.clone();
271
272 let poll_loop = move || -> Result<(), ErrorImpl> {
274 let mut events = Events::with_capacity(8); let mut state = State::new(callback);
276
277 match state.check_diff(&mut file, false, true)? {
280 WatchControl::Continue => (),
281 WatchControl::Stop => return Ok(()),
282 WatchControl::Coalesce { delay } => {
283 state.start_coalescing(delay, &poll)?;
284 }
285 }
286
287 loop {
288 let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
289 if let Err(e) = poll_res {
290 if e.kind() == ErrorKind::Interrupted {
291 continue; } else {
293 return Err(ErrorImpl::PollPoll(e)); }
295 }
296
297 if let Some(event) = events.iter().next() {
300 log::debug!("event on /proc/mounts: {event:?}");
301
302 let coalesced = dbg!(event.token() == TIMER_TOKEN);
304 match state.check_diff(&mut file, coalesced, false)? {
305 WatchControl::Continue => (),
306 WatchControl::Stop => break,
307 WatchControl::Coalesce { delay } => {
308 state.start_coalescing(delay, &poll)?;
309 }
310 }
311 }
312 if stop_flag_thread.load(Ordering::Relaxed) {
313 break;
314 }
315 }
316 Ok(())
317 };
318
319 let thread_handle = std::thread::spawn(move || {
321 if let Err(e) = poll_loop() {
322 log::error!("error in poll loop: {e:?}");
323 }
324 });
325
326 Ok(MountWatcher {
328 thread_handle: Some(thread_handle),
329 stop_flag,
330 })
331}