1use std::{
4 collections::HashSet,
5 fs::File,
6 io::ErrorKind,
7 os::fd::AsRawFd,
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc,
11 },
12 thread::JoinHandle,
13 time::Duration,
14};
15
16use mio::{unix::SourceFd, Events, Interest, Poll, Token};
17use thiserror::Error;
18use timerfd::TimerFd;
19
20use crate::mount::ReadError;
21
22use super::mount::{read_proc_mounts, LinuxMount};
23
/// Watches the system mount table (`/proc/mounts`) on a background thread and reports
/// changes to a user callback.
///
/// Create it with [`MountWatcher::new`]. Call `stop` to end the watch and wait for the
/// thread, or `join` to wait until the callback itself stops the watch. Dropping the
/// watcher only raises the stop flag; it does not wait for the thread.
pub struct MountWatcher {
    /// Handle of the polling thread; `None` once `stop` or `join` has taken it.
    thread_handle: Option<JoinHandle<()>>,
    /// Raised to ask the polling thread to exit; the thread checks it after every
    /// `poll` wake-up.
    stop_flag: Arc<AtomicBool>,
}
57
/// Error returned by [`MountWatcher::new`] when the watcher cannot be started.
///
/// The precise cause is kept private and is only reachable through
/// `std::error::Error::source`.
#[derive(Debug, Error)]
#[error("MountWatcher setup error")]
pub struct SetupError(#[source] ErrorImpl);
62
/// Internal error type of the watcher.
///
/// Setup failures are wrapped in [`SetupError`] and returned to the caller. Errors raised
/// by the poll loop end the background thread and are only logged.
#[derive(Debug, Error)]
enum ErrorImpl {
    /// `/proc/mounts` could not be opened, or `read_proc_mounts` failed on it.
    #[error("read error")]
    MountRead(#[from] ReadError),
    /// Creating the `mio::Poll`, or registering `/proc/mounts` with it, failed.
    #[error("failed to initialize epoll")]
    PollInit(#[source] std::io::Error),
    /// `Poll::poll` failed with an error other than `Interrupted` (which is retried).
    #[error("poll.poll() returned an error")]
    PollPoll(#[source] std::io::Error),
    /// The coalescing timerfd could not be registered with the poll instance.
    #[error("failed to register a timer to epoll")]
    PollTimer(#[source] std::io::Error),
    /// The coalescing timerfd could not be created; carries the requested delay.
    #[error("could not set up a timer with delay {0:?} for event coalescing")]
    Timerfd(Duration, #[source] std::io::Error),
}
77
78impl MountWatcher {
79 pub fn new(
81 callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
82 ) -> Result<Self, SetupError> {
83 watch_mounts(callback).map_err(SetupError)
84 }
85
86 pub fn stop(mut self) -> std::thread::Result<()> {
91 self.stop_flag.store(true, Ordering::Relaxed);
92 self.thread_handle.take().unwrap().join()
93 }
94
95 pub fn join(mut self) -> std::thread::Result<()> {
96 self.thread_handle.take().unwrap().join()
97 }
98}
99
100impl Drop for MountWatcher {
101 fn drop(&mut self) {
102 if self.thread_handle.is_some() {
103 self.stop_flag.store(true, Ordering::Relaxed);
104 }
105 }
106}
107
108pub struct MountEvent {
110 pub mounted: Vec<LinuxMount>,
112
113 pub unmounted: Vec<LinuxMount>,
115
116 pub coalesced: bool,
120
121 pub initial: bool,
124}
125
/// What the watcher should do after the callback has handled a [`MountEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Keep watching; the event is considered handled.
    Continue,
    /// End the watch: the background thread leaves its loop and exits.
    Stop,
    /// Do not consider the event handled yet. Ignore further changes for `delay`, then
    /// report everything that changed since the last handled event in a single event
    /// whose `coalesced` flag is set.
    Coalesce { delay: Duration },
}
138
/// `mio` token identifying readiness events of the `/proc/mounts` file descriptor.
const MOUNT_TOKEN: Token = Token(0);
/// `mio` token identifying expirations of the coalescing timer.
const TIMER_TOKEN: Token = Token(1);
/// Longest single `poll` wait. It bounds how long the loop can go without checking the
/// stop flag.
const POLL_TIMEOUT: Duration = Duration::from_secs(5);
/// Mount table that is read and watched for changes.
const PROC_MOUNTS_PATH: &str = "/proc/mounts";
143
144struct State<F: FnMut(MountEvent) -> WatchControl> {
145 known_mounts: HashSet<LinuxMount>,
146 callback: F,
147 coalesce_timer: Option<TimerFd>,
148 coalescing: bool,
149}
150
151impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
152 fn new(callback: F) -> Self {
153 Self {
154 known_mounts: HashSet::with_capacity(8),
155 callback,
156 coalesce_timer: None,
157 coalescing: false,
158 }
159 }
160
161 fn check_diff(
162 &mut self,
163 file: &mut File,
164 coalesced: bool,
165 initial: bool,
166 ) -> Result<WatchControl, ReadError> {
167 debug_assert!(
168 !(coalesced && !self.coalescing),
169 "inconsistent state: coalescing flag should be set before setting the trigger up"
170 );
171 if self.coalescing {
172 if coalesced {
173 self.coalescing = false;
175 } else {
176 return Ok(WatchControl::Continue);
178 }
179 }
180
181 let mounts = read_proc_mounts(file)?;
182 let mounts = HashSet::from_iter(mounts);
183 let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
184 let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
185 log::trace!("known_mounts: {:?}", self.known_mounts);
186 log::trace!("curr. mounts: {:?}", mounts);
187
188 if mounted.is_empty() && unmounted.is_empty() {
189 log::warn!("nothing changed");
193 return Ok(WatchControl::Continue);
194 }
195
196 let event = MountEvent {
198 mounted: mounted.into_iter().cloned().collect(),
199 unmounted: unmounted.into_iter().cloned().collect(),
200 coalesced,
201 initial,
202 };
203 let res = (self.callback)(event);
204 if !matches!(res, WatchControl::Coalesce { .. }) {
205 self.known_mounts = mounts;
209 }
210 Ok(res)
212 }
213
214 fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
215 log::trace!("start coalescing for {delay:?}");
216 let mut register = false;
217 if self.coalesce_timer.is_none() {
218 self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
220 register = true;
221 log::trace!("timerfd created");
222 }
223
224 let timer = self.coalesce_timer.as_mut().unwrap();
226 timer.set_state(
227 timerfd::TimerState::Oneshot(delay),
228 timerfd::SetTimeFlags::Default,
229 );
230
231 if register {
233 let fd = timer.as_raw_fd();
234 let mut source = SourceFd(&fd);
235 poll.registry()
236 .register(&mut source, TIMER_TOKEN, Interest::READABLE)
237 .map_err(ErrorImpl::PollTimer)?;
238 log::trace!("timerfd registered");
239 }
240 self.coalescing = true;
242 Ok(())
243 }
244}
245
246fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
248 callback: F,
249) -> Result<MountWatcher, ErrorImpl> {
250 let mut file =
252 File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
253 let fd = file.as_raw_fd();
254 let mut fd = SourceFd(&fd);
255
256 let mut poll = Poll::new().map_err(|e| ErrorImpl::PollInit(e))?;
260 poll.registry()
261 .register(&mut fd, MOUNT_TOKEN, Interest::PRIORITY)
262 .map_err(|e| ErrorImpl::PollInit(e))?;
263
264 let stop_flag = Arc::new(AtomicBool::new(false));
266 let stop_flag_thread = stop_flag.clone();
267
268 let poll_loop = move || -> Result<(), ErrorImpl> {
270 let mut events = Events::with_capacity(8); let mut state = State::new(callback);
272
273 match state.check_diff(&mut file, false, true)? {
276 WatchControl::Continue => (),
277 WatchControl::Stop => return Ok(()),
278 WatchControl::Coalesce { delay } => {
279 state.start_coalescing(delay, &poll)?;
280 }
281 }
282
283 loop {
284 let poll_res = poll.poll(&mut events, Some(POLL_TIMEOUT));
285 if let Err(e) = poll_res {
286 if e.kind() == ErrorKind::Interrupted {
287 continue; } else {
289 return Err(ErrorImpl::PollPoll(e)); }
291 }
292
293 if let Some(event) = events.iter().next() {
296 log::debug!("event on /proc/mounts: {event:?}");
297
298 let coalesced = dbg!(event.token() == TIMER_TOKEN);
300 match state.check_diff(&mut file, coalesced, false)? {
301 WatchControl::Continue => (),
302 WatchControl::Stop => break,
303 WatchControl::Coalesce { delay } => {
304 state.start_coalescing(delay, &poll)?;
305 }
306 }
307 }
308 if stop_flag_thread.load(Ordering::Relaxed) {
309 break;
310 }
311 }
312 Ok(())
313 };
314
315 let thread_handle = std::thread::spawn(move || {
317 if let Err(e) = poll_loop() {
318 log::error!("error in poll loop: {e:?}");
319 }
320 });
321
322 Ok(MountWatcher {
324 thread_handle: Some(thread_handle),
325 stop_flag,
326 })
327}