use std::{
collections::HashSet, fs::File, io::ErrorKind, os::fd::AsRawFd, sync::Arc, thread::JoinHandle,
time::Duration,
};
use mio::{unix::SourceFd, Events, Interest, Poll, Token, Waker};
use thiserror::Error;
use timerfd::TimerFd;
use crate::mount::ReadError;
use super::mount::{read_proc_mounts, LinuxMount};
/// Handle to a background thread that watches `/proc/mounts` for mount/unmount changes.
///
/// Dropping the handle asks the thread to stop but does not wait for it; call
/// [`MountWatcher::stop`] and then [`MountWatcher::join`] to shut down synchronously.
pub struct MountWatcher {
    /// Wakes the thread's poller so that it leaves its event loop.
    stop_waker: Arc<Waker>,
    /// The polling thread; `None` only once [`MountWatcher::join`] has taken it.
    thread_handle: Option<JoinHandle<()>>,
}
#[derive(Debug, Error)]
#[error("MountWatcher setup error")]
pub struct SetupError(#[source] ErrorImpl);
#[derive(Debug, Error)]
#[error("MountWatcher stop error")]
pub struct StopError(#[source] ErrorImpl);
#[derive(Debug, Error)]
enum ErrorImpl {
#[error("read error")]
MountRead(#[from] ReadError),
#[error("failed to initialize epoll")]
PollInit(#[source] std::io::Error),
#[error("poll.poll() returned an error")]
PollPoll(#[source] std::io::Error),
#[error("failed to register a timer to epoll")]
PollTimer(#[source] std::io::Error),
#[error("could not set up a timer with delay {0:?} for event coalescing")]
Timerfd(Duration, #[source] std::io::Error),
#[error("failed to stop epoll from another thread")]
Stop(#[source] std::io::Error),
}
impl MountWatcher {
    /// Starts watching `/proc/mounts` on a dedicated background thread.
    ///
    /// `callback` runs on that thread. Its first invocation (if the mount table is not
    /// empty) has `initial == true` and lists every mount currently present; later
    /// invocations describe changes. Its return value steers the watcher, see
    /// [`WatchControl`].
    ///
    /// # Errors
    /// Returns [`SetupError`] if `/proc/mounts` cannot be opened or the poller setup fails.
    pub fn new(
        callback: impl FnMut(MountEvent) -> WatchControl + Send + 'static,
    ) -> Result<Self, SetupError> {
        watch_mounts(callback).map_err(SetupError)
    }

    /// Asks the background thread to leave its loop.
    ///
    /// This returns immediately; use [`MountWatcher::join`] to wait for the thread.
    ///
    /// # Errors
    /// Returns [`StopError`] if the wake-up could not be delivered to the thread.
    pub fn stop(&self) -> Result<(), StopError> {
        self.stop_waker
            .wake()
            .map_err(|e| StopError(ErrorImpl::Stop(e)))
    }

    /// Waits for the background thread to finish.
    ///
    /// This does not stop the thread by itself: call [`MountWatcher::stop`] first unless the
    /// callback is expected to return [`WatchControl::Stop`].
    ///
    /// # Errors
    /// Returns the panic payload if the thread (for example the callback) panicked.
    pub fn join(mut self) -> std::thread::Result<()> {
        // Taking the handle also tells `Drop` that the thread no longer needs a stop request.
        self.thread_handle
            .take()
            .expect("thread_handle is only taken by join, which consumes the watcher")
            .join()
    }
}
impl Drop for MountWatcher {
fn drop(&mut self) {
if self.thread_handle.is_some() {
let _ = self.stop();
}
}
}
/// A change in the set of mounts listed in `/proc/mounts`, passed to the watcher callback.
///
/// The order of entries in `mounted` and `unmounted` is unspecified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MountEvent {
    /// Mounts present now that were not in the last accepted snapshot.
    pub mounted: Vec<LinuxMount>,
    /// Mounts from the last accepted snapshot that are no longer present.
    pub unmounted: Vec<LinuxMount>,
    /// `true` when this event is delivered at the end of a delay requested with
    /// [`WatchControl::Coalesce`]. It is computed against the snapshot from before that
    /// request, so it also contains the changes of the event that requested it.
    pub coalesced: bool,
    /// `true` for the first event, which lists every mount present when watching started.
    pub initial: bool,
}
/// What the watcher should do after the callback has handled a [`MountEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WatchControl {
    /// Accept the reported changes and keep watching.
    Continue,
    /// Stop the background thread.
    Stop,
    /// Do not accept the reported changes yet: ignore further notifications for `delay`,
    /// then deliver one event (with `coalesced == true`) describing everything that
    /// changed since the last accepted snapshot.
    Coalesce { delay: Duration },
}
/// Path of the kernel mount table that is watched for changes.
const PROC_MOUNTS_PATH: &str = "/proc/mounts";

// Tokens identifying the three event sources registered on the poller.
/// Readiness of `/proc/mounts`, i.e. the mount table changed.
const MOUNT_TOKEN: Token = Token(0);
/// Expiry of the coalescing timer.
const TIMER_TOKEN: Token = Token(1);
/// Wake-up sent by [`MountWatcher::stop`].
const STOP_TOKEN: Token = Token(2);

/// Upper bound on a single wait for events; on timeout the loop simply waits again.
const POLL_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Mutable state owned by the polling thread.
///
/// No trait bound on `F` here: the requirement is expressed on the `impl` that needs it.
struct State<F> {
    /// Last snapshot of `/proc/mounts` the callback accepted (did not ask to coalesce).
    known_mounts: HashSet<LinuxMount>,
    /// User callback invoked with every non-empty diff.
    callback: F,
    /// One-shot timer ending a coalescing window; created lazily on first use.
    coalesce_timer: Option<TimerFd>,
    /// Whether a coalescing window is currently open.
    coalescing: bool,
}
impl<F: FnMut(MountEvent) -> WatchControl> State<F> {
fn new(callback: F) -> Self {
Self {
known_mounts: HashSet::with_capacity(8),
callback,
coalesce_timer: None,
coalescing: false,
}
}
fn check_diff(
&mut self,
file: &mut File,
coalesced: bool,
initial: bool,
) -> Result<WatchControl, ReadError> {
debug_assert!(
!(coalesced && !self.coalescing),
"inconsistent state: coalescing flag should be set before setting the trigger up"
);
if self.coalescing {
if coalesced {
self.coalescing = false;
} else {
return Ok(WatchControl::Continue);
}
}
let mounts = read_proc_mounts(file)?;
let mounts = HashSet::from_iter(mounts);
let unmounted: Vec<&LinuxMount> = self.known_mounts.difference(&mounts).collect();
let mounted: Vec<&LinuxMount> = mounts.difference(&self.known_mounts).collect();
log::trace!("known_mounts: {:?}", self.known_mounts);
log::trace!("curr. mounts: {:?}", mounts);
if mounted.is_empty() && unmounted.is_empty() {
log::warn!("nothing changed");
return Ok(WatchControl::Continue);
}
let event = MountEvent {
mounted: mounted.into_iter().cloned().collect(),
unmounted: unmounted.into_iter().cloned().collect(),
coalesced,
initial,
};
let res = (self.callback)(event);
if !matches!(res, WatchControl::Coalesce { .. }) {
self.known_mounts = mounts;
}
Ok(res)
}
fn start_coalescing(&mut self, delay: Duration, poll: &Poll) -> Result<(), ErrorImpl> {
log::trace!("start coalescing for {delay:?}");
let mut register = false;
if self.coalesce_timer.is_none() {
self.coalesce_timer = Some(TimerFd::new().map_err(|e| ErrorImpl::Timerfd(delay, e))?);
register = true;
log::trace!("timerfd created");
}
let timer = self.coalesce_timer.as_mut().unwrap();
timer.set_state(
timerfd::TimerState::Oneshot(delay),
timerfd::SetTimeFlags::Default,
);
if register {
let fd = timer.as_raw_fd();
let mut source = SourceFd(&fd);
poll.registry()
.register(&mut source, TIMER_TOKEN, Interest::READABLE)
.map_err(ErrorImpl::PollTimer)?;
log::trace!("timerfd registered");
}
self.coalescing = true;
Ok(())
}
}
/// Opens `/proc/mounts`, sets up the poller and spawns the background polling thread.
///
/// The thread first reports the initial mount table (`initial == true`). It then waits for
/// three kinds of wake-ups: changes to `/proc/mounts`, expiry of the coalescing timer, and
/// the stop request sent through the returned watcher's waker.
///
/// # Errors
/// [`ErrorImpl::MountRead`] if `/proc/mounts` cannot be opened. [`ErrorImpl::PollInit`] if
/// the poller or waker cannot be created, or if the file cannot be registered. Errors
/// occurring later, inside the thread, are only logged.
fn watch_mounts<F: FnMut(MountEvent) -> WatchControl + Send + 'static>(
    callback: F,
) -> Result<MountWatcher, ErrorImpl> {
    let mut file =
        File::open(PROC_MOUNTS_PATH).map_err(|e| ErrorImpl::MountRead(ReadError::Io(e)))?;
    let raw_fd = file.as_raw_fd();
    let mut poll = Poll::new().map_err(ErrorImpl::PollInit)?;
    let stop_waker = Waker::new(poll.registry(), STOP_TOKEN).map_err(ErrorImpl::PollInit)?;
    // Mount table changes are signalled as priority (POLLPRI) readiness, see proc(5).
    poll.registry()
        .register(&mut SourceFd(&raw_fd), MOUNT_TOKEN, Interest::PRIORITY)
        .map_err(ErrorImpl::PollInit)?;

    let poll_loop = move || -> Result<(), ErrorImpl> {
        let mut events = Events::with_capacity(8);
        let mut state = State::new(callback);

        match state.check_diff(&mut file, false, true)? {
            WatchControl::Continue => (),
            WatchControl::Stop => return Ok(()),
            WatchControl::Coalesce { delay } => state.start_coalescing(delay, &poll)?,
        }

        loop {
            if let Err(e) = poll.poll(&mut events, Some(POLL_TIMEOUT)) {
                if e.kind() == ErrorKind::Interrupted {
                    continue;
                }
                return Err(ErrorImpl::PollPoll(e));
            }

            // mio notifications are edge-triggered, so every event of this batch must be
            // taken into account now: a skipped event is not reported again. Skipping a
            // timer expiry would leave `state.coalescing` set forever (silently ignoring all
            // later changes). Skipping the stop wake-up could leave the thread running after
            // `stop()`.
            let mut stop_requested = false;
            let mut timer_expired = false;
            let mut mounts_changed = false;
            for event in events.iter() {
                log::debug!("event on /proc/mounts: {event:?}");
                match event.token() {
                    STOP_TOKEN => stop_requested = true,
                    TIMER_TOKEN => timer_expired = true,
                    MOUNT_TOKEN => mounts_changed = true,
                    _ => (),
                }
            }

            if stop_requested {
                break;
            }
            if !timer_expired && !mounts_changed {
                // Poll timeout without events: wait again.
                continue;
            }
            // A timer expiry re-reads /proc/mounts, so it also covers a mount notification
            // delivered in the same batch.
            match state.check_diff(&mut file, timer_expired, false)? {
                WatchControl::Continue => (),
                WatchControl::Stop => break,
                WatchControl::Coalesce { delay } => state.start_coalescing(delay, &poll)?,
            }
        }
        Ok(())
    };

    let thread_handle = std::thread::spawn(move || {
        if let Err(e) = poll_loop() {
            log::error!("error in polling loop: {e:?}");
        }
    });
    Ok(MountWatcher {
        thread_handle: Some(thread_handle),
        stop_waker: Arc::new(stop_waker),
    })
}