use crate::passthrough::stat::{statx, MountId};
use crate::passthrough::util::openat;
use crate::util::ResultErrorContext;
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::fs::File;
use std::io::{self, Read, Seek};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::sync::{Arc, Mutex, RwLock, Weak};
/// A file opened on some mount, tagged with that mount's ID.
///
/// `MountFd`s created through a [`MountFds`] are handed out as `Arc`s and registered in its
/// map, so that lookups for the same mount ID share one live instance; the registration is
/// removed again when the last `Arc` is dropped.
pub struct MountFd {
    // Back-reference to the owning `MountFds` map, used on drop to unregister this entry.
    // `Weak::new()` when the `MountFd` was created without a `MountFds`.
    map: Weak<RwLock<HashMap<MountId, Weak<MountFd>>>>,
    // Mount ID of the filesystem `file` resides on.
    mount_id: MountId,
    // The open file on that mount.
    file: File,
}
/// Registry of [`MountFd`]s keyed by mount ID, plus the state needed to open new ones.
pub struct MountFds {
    /// Mount ID -> weak handle of the live `MountFd` for that mount. Kept in an `Arc` because
    /// every registered `MountFd` holds a `Weak` back-reference to this map.
    map: Arc<RwLock<HashMap<MountId, Weak<MountFd>>>>,
    /// Open mountinfo file (the error messages assume `/proc/self/mountinfo`); it is rewound
    /// and re-read on every mount root lookup.
    mountinfo: Mutex<File>,
    /// If set, prefix that is stripped from mount points read from `mountinfo`.
    mountprefix: Option<String>,
    /// Mount IDs for which an error has already been produced; later errors for the same ID
    /// are silenced. Only ever accessed through `&self`, so no `Arc` is needed.
    error_logged: RwLock<HashSet<MountId>>,
}
impl MountFd {
    /// Opens `path` read-only (relative to `dir`) and returns a `MountFd` for the mount it
    /// resides on.
    ///
    /// With `mount_fds`, results are deduplicated: if a live `MountFd` for the same mount ID
    /// is already registered, that one is returned and the freshly opened file is closed;
    /// otherwise the new `MountFd` is registered. Without `mount_fds`, a standalone,
    /// unregistered `MountFd` is returned.
    ///
    /// # Errors
    /// Fails if `path` cannot be opened or its mount ID cannot be determined.
    pub fn new<D: AsRawFd>(
        mount_fds: Option<&MountFds>,
        dir: &D,
        path: &str,
    ) -> io::Result<Arc<MountFd>> {
        let opened =
            openat(dir, path, libc::O_RDONLY).err_context(|| format!("Failed to open {path}"))?;
        let mnt_id = statx(&opened, None)
            .err_context(|| format!("Failed to get {path}'s mount ID"))?
            .mnt_id;

        let registry = match mount_fds {
            Some(registry) => registry,
            None => {
                // Nothing to register with: hand out an unshared instance.
                return Ok(Arc::new(MountFd {
                    map: Weak::new(),
                    mount_id: mnt_id,
                    file: opened,
                }));
            }
        };

        // Check and insert under the same write lock so two callers cannot both register.
        let mut entries = registry.map.write().unwrap();
        if let Some(live) = entries.get(&mnt_id).and_then(Weak::upgrade) {
            return Ok(live);
        }
        let created = Arc::new(MountFd {
            map: Arc::downgrade(&registry.map),
            mount_id: mnt_id,
            file: opened,
        });
        entries.insert(mnt_id, Arc::downgrade(&created));
        Ok(created)
    }

    /// Borrows the open file.
    pub fn file(&self) -> &File {
        &self.file
    }

    /// Returns the mount ID the file was found to reside on.
    pub fn mount_id(&self) -> MountId {
        self.mount_id
    }
}
/// Error type returned by the `MountFds` lookups (`get()`, `get_mount_root()`).
///
/// Wraps an `io::Error` together with a human-readable description and, optionally, the
/// mount ID and mount root of the filesystem the error concerns.
#[derive(Debug)]
pub struct MPRError {
    // The underlying I/O error, returned by `into_inner()`.
    io: io::Error,
    // Human-readable message shown by `Display`.
    description: String,
    // Set by `MountFds` when an error for the same mount ID was already produced earlier;
    // presumably lets callers skip logging it again.
    silent: bool,
    // Mount ID of the affected filesystem, if known.
    fs_mount_id: Option<MountId>,
    // Mount point of the affected filesystem, if known.
    fs_mount_root: Option<String>,
}
/// `Result` alias whose error type is [`MPRError`].
pub type MPRResult<T> = Result<T, MPRError>;
impl Drop for MountFd {
    /// Unregisters this `MountFd` from its `MountFds` map, if any, unless the map entry has
    /// meanwhile been replaced by a newer, still-live `MountFd` for the same mount.
    fn drop(&mut self) {
        debug!(
            "Dropping MountFd: mount_id={}, mount_fd={}",
            self.mount_id,
            self.file.as_raw_fd(),
        );

        let registry = match self.map.upgrade() {
            Some(registry) => registry,
            // Created without a `MountFds`, or the `MountFds` is already gone.
            None => return,
        };
        let mut entries = registry.write().unwrap();
        // When `drop` runs, our own strong count is already 0. A zero count for the entry
        // therefore means it still points at a dead `MountFd`. A nonzero count means a live
        // replacement has been installed since, and it must be kept.
        let stale = matches!(entries.get(&self.mount_id).map(Weak::strong_count), Some(0));
        if stale {
            entries.remove(&self.mount_id);
        }
    }
}
/// Wraps any error that converts into `io::Error` as an `MPRError`.
///
/// The error's `to_string()` output becomes the initial description. The result is not
/// silenced and carries no mount information.
impl<E: ToString + Into<io::Error>> From<E> for MPRError {
    fn from(err: E) -> Self {
        // Render the message before `into()` consumes `err`.
        let message = err.to_string();
        let source: io::Error = err.into();
        MPRError {
            io: source,
            description: message,
            silent: false,
            fs_mount_id: None,
            fs_mount_root: None,
        }
    }
}
impl MPRError {
#[must_use]
pub fn set_desc(mut self, s: String) -> Self {
self.description = s;
self
}
#[must_use]
pub fn prefix(self, s: String) -> Self {
let new_desc = format!("{s}: {}", self.description);
self.set_desc(new_desc)
}
#[must_use]
fn set_mount_id(mut self, mount_id: MountId) -> Self {
self.fs_mount_id = Some(mount_id);
self
}
#[must_use]
fn set_mount_root(mut self, mount_root: String) -> Self {
self.fs_mount_root = Some(mount_root);
self
}
#[must_use]
fn silence(mut self) -> Self {
self.silent = true;
self
}
pub fn silent(&self) -> bool {
self.silent
}
pub fn into_inner(self) -> io::Error {
self.io
}
}
impl std::fmt::Display for MPRError {
    /// Writes the description. If the affected filesystem is known, the description is
    /// prefixed with its mount root and/or mount ID.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let desc = &self.description;
        match (self.fs_mount_root.as_ref(), self.fs_mount_id) {
            (Some(root), Some(id)) => write!(
                f,
                "Filesystem mounted on \"{root}\" (mount ID: {id}): {}",
                desc
            ),
            (Some(root), None) => write!(f, "Filesystem mounted on \"{root}\": {}", desc),
            (None, Some(id)) => write!(f, "Filesystem with mount ID {id}: {}", desc),
            (None, None) => write!(f, "{}", desc),
        }
    }
}
/// Lets `MPRError` be used wherever a `std::error::Error` is expected. It relies on the trait's
/// default methods, so `source()` reports no underlying cause.
impl std::error::Error for MPRError {}
impl MountFds {
    /// Creates an empty registry.
    ///
    /// `mountinfo` is rewound and re-read on every mount root lookup. `mountprefix`, if set,
    /// is stripped from the mount points found there (see `get_mount_root()`).
    pub fn new(mountinfo: File, mountprefix: Option<String>) -> Self {
        MountFds {
            map: Default::default(),
            mountinfo: Mutex::new(mountinfo),
            mountprefix,
            error_logged: Default::default(),
        }
    }

    /// Returns the `MountFd` for `mount_id`, opening the mount's root if none is live.
    ///
    /// A new `MountFd` is created in these steps:
    /// 1. Resolve the mount point via `get_mount_root()`.
    /// 2. Open it with `O_PATH`.
    /// 3. Check that it belongs to `mount_id` and is a regular file or directory.
    /// 4. Call `reopen_fd(fd, flags)` to get a readable file from the `O_PATH` descriptor.
    ///
    /// # Errors
    /// Any failure along the way is returned as an `MPRError` tagged with `mount_id`.
    pub fn get<F>(&self, mount_id: MountId, reopen_fd: F) -> MPRResult<Arc<MountFd>>
    where
        F: FnOnce(RawFd, libc::c_int) -> io::Result<File>,
    {
        // Fast path under the read lock. The guard is a temporary and is released at the
        // end of this statement.
        let existing_mount_fd = self
            .map
            .read()
            .unwrap()
            .get(&mount_id)
            .and_then(Weak::upgrade);
        if let Some(mount_fd) = existing_mount_fd {
            return Ok(mount_fd);
        }

        let mount_point = self.get_mount_root(mount_id)?;
        let c_mount_point = CString::new(mount_point.as_str()).map_err(|e| {
            self.error_for(mount_id, e)
                .prefix(format!("Failed to convert \"{mount_point}\" to a CString"))
        })?;
        // O_CLOEXEC keeps this descriptor from leaking into child processes that other
        // threads may spawn while it is open.
        // SAFETY: `c_mount_point` is a valid NUL-terminated C string that outlives the call.
        let mount_point_fd =
            unsafe { libc::open(c_mount_point.as_ptr(), libc::O_PATH | libc::O_CLOEXEC) };
        if mount_point_fd < 0 {
            return Err(self
                .error_for(mount_id, io::Error::last_os_error())
                .prefix(format!("Failed to open mount point \"{mount_point}\"")));
        }
        // SAFETY: `mount_point_fd` was just returned by a successful `open()` and is owned by
        // nothing else; `File` takes ownership and closes it on drop.
        let mount_point_path = unsafe { File::from_raw_fd(mount_point_fd) };

        // The path comes from mountinfo text, so verify that what we actually opened is on
        // the expected mount before trusting it.
        let stx = statx(&mount_point_path, None).map_err(|e| {
            self.error_for(mount_id, e)
                .prefix(format!("Failed to stat mount point \"{mount_point}\""))
        })?;
        if stx.mnt_id != mount_id {
            return Err(self
                .error_for(mount_id, io::Error::from_raw_os_error(libc::EIO))
                .set_desc(format!(
                    "Mount point's ({mount_point}) mount ID ({}) does not match expected value ({mount_id})",
                    stx.mnt_id
                )));
        }
        let file_type = stx.st.st_mode & libc::S_IFMT;
        if file_type != libc::S_IFREG && file_type != libc::S_IFDIR {
            return Err(self
                .error_for(mount_id, io::Error::from_raw_os_error(libc::EIO))
                .set_desc(format!(
                    "Mount point \"{mount_point}\" is not a regular file or directory"
                )));
        }

        let file = reopen_fd(
            mount_point_path.as_raw_fd(),
            libc::O_RDONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC,
        )
        .map_err(|e| {
            self.error_for(mount_id, e).prefix(format!(
                "Failed to reopen mount point \"{mount_point}\" for reading"
            ))
        })?;

        let mut mount_fds_locked = self.map.write().unwrap();
        // Another thread may have registered this mount while we were opening it. In that
        // case use its `MountFd` and let our `file` be closed.
        if let Some(mount_fd) = mount_fds_locked.get(&mount_id).and_then(Weak::upgrade) {
            return Ok(mount_fd);
        }
        debug!(
            "Creating MountFd: mount_id={mount_id}, mount_fd={}",
            file.as_raw_fd(),
        );
        let mount_fd = Arc::new(MountFd {
            map: Arc::downgrade(&self.map),
            mount_id,
            file,
        });
        mount_fds_locked.insert(mount_id, Arc::downgrade(&mount_fd));
        Ok(mount_fd)
    }

    /// Looks up the mount point of `mount_id` in the mountinfo file.
    ///
    /// The kernel escapes some characters in paths with octal `\ooo` sequences; these are
    /// decoded. If `mountprefix` is set, it is stripped from the result on a path-component
    /// boundary. Mount points equal to the prefix, or outside it, are reported as `"/"`.
    ///
    /// # Errors
    /// Fails if the mountinfo file cannot be read, or if `mount_id` is not listed in it.
    pub fn get_mount_root(&self, mount_id: MountId) -> MPRResult<String> {
        let mountinfo = {
            let mountinfo_file = &mut *self.mountinfo.lock().unwrap();
            mountinfo_file.rewind().map_err(|e| {
                self.error_for_nolookup(mount_id, e)
                    .prefix("Failed to access /proc/self/mountinfo".into())
            })?;
            let mut mountinfo = String::new();
            mountinfo_file.read_to_string(&mut mountinfo).map_err(|e| {
                self.error_for_nolookup(mount_id, e)
                    .prefix("Failed to read /proc/self/mountinfo".into())
            })?;
            mountinfo
        };

        // Each line starts with "<mount ID> <parent ID> <major:minor> <root> <mount point>".
        // Fields are separated by single spaces. The kernel escapes only space, tab, newline
        // and backslash inside paths, so other whitespace characters (e.g. Unicode spaces)
        // can appear literally. Split on ' ' only.
        let path = mountinfo.lines().find_map(|line| {
            let mut columns = line.split(' ');
            if columns.next()?.parse::<MountId>().ok()? != mount_id {
                return None;
            }
            columns.nth(3)
        });

        match path {
            Some(p) => {
                let p = Self::unescape_mountinfo_path(p);
                match self.mountprefix.as_deref() {
                    Some(prefix) => Ok(Self::strip_mount_prefix(&p, prefix)),
                    None => Ok(p),
                }
            }
            None => Err(self
                .error_for_nolookup(mount_id, io::Error::from_raw_os_error(libc::EINVAL))
                .set_desc(format!("Failed to find mount root for mount ID {mount_id}"))),
        }
    }

    /// Decodes the `\ooo` octal escapes used in mountinfo path fields.
    ///
    /// The kernel uses these escapes for space, tab, newline and backslash. Sequences that
    /// are not a valid escape are kept verbatim.
    fn unescape_mountinfo_path(field: &str) -> String {
        let bytes = field.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            // A leading digit of at most 3 keeps the value <= 0o377, so it fits in a u8.
            if let [b'\\', hi @ b'0'..=b'3', mid @ b'0'..=b'7', lo @ b'0'..=b'7', ..] =
                &bytes[i..]
            {
                decoded.push((hi - b'0') * 64 + (mid - b'0') * 8 + (lo - b'0'));
                i += 4;
            } else {
                decoded.push(bytes[i]);
                i += 1;
            }
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }

    /// Maps `mount_point` to its path below `prefix`.
    ///
    /// Returns `"/"` if `mount_point` is `prefix` itself or does not lie below it.
    fn strip_mount_prefix(mount_point: &str, prefix: &str) -> String {
        // Ignore trailing slashes so "/shared/" and "/shared" behave alike, and "/" becomes "".
        let prefix = prefix.trim_end_matches('/');
        match mount_point.strip_prefix(prefix) {
            // Accept only matches on a component boundary. Otherwise prefix "/shared" would
            // turn "/sharedfoo" into the relative path "foo".
            Some(suffix) if suffix.starts_with('/') => suffix.into(),
            _ => "/".into(),
        }
    }

    /// Wraps `err` in an `MPRError` tagged with `mount_id`, without looking up the mount root.
    ///
    /// Only the first error produced for a given mount ID stays loud; later ones are silenced.
    fn error_for_nolookup<E: ToString + Into<io::Error>>(
        &self,
        mount_id: MountId,
        err: E,
    ) -> MPRError {
        let err = MPRError::from(err).set_mount_id(mount_id);
        // Do a cheap check under the read lock first. If the ID looks new, `insert()` under
        // the write lock decides atomically. It returns `false` if another thread recorded
        // the ID in between, and then this error is silenced too.
        let already_logged = self.error_logged.read().unwrap().contains(&mount_id);
        let first_report = !already_logged && self.error_logged.write().unwrap().insert(mount_id);
        if first_report {
            err
        } else {
            err.silence()
        }
    }

    /// Like `error_for_nolookup()`, but also attaches the mount root when the error is not
    /// silenced and the root can be determined.
    pub fn error_for<E: ToString + Into<io::Error>>(&self, mount_id: MountId, err: E) -> MPRError {
        let err = self.error_for_nolookup(mount_id, err);
        if err.silent() {
            // A silenced error presumably will not be shown, so skip the mountinfo re-read.
            return err;
        }
        match self.get_mount_root(mount_id) {
            Ok(mount_root) => err.set_mount_root(mount_root),
            Err(_) => err,
        }
    }
}