#![cfg(target_os = "linux")]
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::fs::File;
use std::io::{self, Read, Seek};
use std::os::fd::{AsFd, BorrowedFd};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::sync::{Arc, Mutex, RwLock, Weak};

use tracing::debug;

use super::MOUNT_INFO_FILE;
use super::statx::statx;
use super::util::{einval, is_safe_inode};
/// Kernel mount ID: the value compared against `statx`'s `mnt_id` and parsed from the
/// first column of mountinfo.
pub type MountId = u64;
/// An open file for the root of a single mount, shared via `Arc` by `MountFds::get`.
///
/// When the last `Arc` is dropped, the `Drop` impl removes the corresponding (now dead)
/// entry from the owning table, provided that table still exists.
pub struct MountFd {
// File returned by the `reopen_fd` callback given to `MountFds::get` (which requests
// `O_RDONLY | O_NOFOLLOW | O_CLOEXEC`).
file: File,
// Key of this FD's entry in the owning table.
mount_id: MountId,
// Weak back-reference to the owning table, so a `MountFd` never keeps the table alive.
map: Weak<RwLock<HashMap<MountId, Weak<MountFd>>>>,
}
impl AsFd for MountFd {
    /// Borrows the descriptor of the opened mount root for as long as `self` lives.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.file)
    }
}
impl Drop for MountFd {
fn drop(&mut self) {
debug!(
"Dropping MountFd: mount_id={}, mount_fd={}",
self.mount_id,
self.file.as_raw_fd(),
);
if let Some(map) = self.map.upgrade() {
let mut map = map.write().unwrap();
if let Some(0) = map.get(&self.mount_id).map(Weak::strong_count) {
map.remove(&self.mount_id);
}
}
}
}
/// Cache of per-mount root file descriptors, keyed by mount ID.
pub struct MountFds {
// Weak references, so the table never keeps a `MountFd` alive; dead entries are removed
// by `MountFd`'s `Drop` impl.
map: Arc<RwLock<HashMap<MountId, Weak<MountFd>>>>,
// Open mountinfo file; rewound and re-read under the lock on every mount-root lookup.
mount_info: Mutex<File>,
// Optional path prefix used to translate mount points read from mountinfo
// (see `get_mount_root`).
mount_prefix: Option<String>,
// Mount IDs for which an error has already been produced; later errors for the same
// mount are returned silenced.
error_logged: Arc<RwLock<HashSet<MountId>>>,
}
impl MountFds {
    /// Creates an empty table that reads mount information from [`MOUNT_INFO_FILE`].
    ///
    /// `mount_prefix`, if given, is used to translate mount points listed in mountinfo
    /// before they are opened (see `get_mount_root`).
    ///
    /// # Errors
    ///
    /// Returns the error from opening [`MOUNT_INFO_FILE`].
    pub fn new(mount_prefix: Option<String>) -> io::Result<Self> {
        let mount_info_file = File::open(MOUNT_INFO_FILE)?;
        Ok(Self::with_mount_info_file(mount_info_file, mount_prefix))
    }

    /// Creates an empty table that reads mount information from the already-open
    /// `mount_info` file, which must use the mountinfo line format.
    pub fn with_mount_info_file(mount_info: File, mount_prefix: Option<String>) -> Self {
        MountFds {
            map: Default::default(),
            mount_info: Mutex::new(mount_info),
            mount_prefix,
            error_logged: Default::default(),
        }
    }

    /// Returns a shared FD for the root of the mount identified by `mount_id`.
    ///
    /// A live cached `MountFd` is returned directly. Otherwise the mount point is looked up
    /// in mountinfo, opened with `O_PATH`, checked to belong to `mount_id` and to be a safe
    /// inode type, and handed to `reopen_fd` as `(o_path_fd, flags, st_mode)`. The `File`
    /// that `reopen_fd` returns is cached and returned.
    ///
    /// # Errors
    ///
    /// Fails if the mount point cannot be found, converted to a C string, opened,
    /// validated, or reopened. Errors for a mount ID that already produced an error are
    /// returned silenced (see `MPRError::silent`).
    pub fn get<F>(&self, mount_id: MountId, reopen_fd: F) -> MPRResult<Arc<MountFd>>
    where
        F: FnOnce(RawFd, libc::c_int, u32) -> io::Result<File>,
    {
        // Fast path: reuse a live entry while holding only the shared lock.
        let existing_mount_fd = self
            .map
            .read()
            .unwrap()
            .get(&mount_id)
            .and_then(Weak::upgrade);
        if let Some(mount_fd) = existing_mount_fd {
            return Ok(mount_fd);
        }

        let mount_point = self.get_mount_root(mount_id)?;
        let c_mount_point = CString::new(mount_point.as_str()).map_err(|e| {
            self.error_for(mount_id, e)
                .prefix(format!("Failed to convert \"{mount_point}\" to a CString"))
        })?;
        // SAFETY: `c_mount_point` is a valid NUL-terminated string that outlives the call.
        let raw_fd =
            unsafe { libc::open(c_mount_point.as_ptr(), libc::O_PATH | libc::O_CLOEXEC) };
        if raw_fd < 0 {
            return Err(self
                .error_for(mount_id, io::Error::last_os_error())
                .prefix(format!("Failed to open mount point \"{mount_point}\"")));
        }
        // SAFETY: `raw_fd` was just returned by a successful `open()` and nothing else owns
        // it. Taking ownership guarantees it is closed on every exit path below, including
        // the early error returns.
        let mount_point_file = unsafe { File::from_raw_fd(raw_fd) };

        let st_mode = self.validate_mount_id(mount_id, &mount_point_file, &mount_point)?;
        let file_type = st_mode & (libc::S_IFMT as libc::mode_t);
        if !is_safe_inode(file_type as u32) {
            return Err(self
                .error_for(mount_id, io::Error::from_raw_os_error(libc::EIO))
                .set_desc(format!(
                    "Mount point \"{mount_point}\" is not a regular file or directory"
                )));
        }
        let file = reopen_fd(
            mount_point_file.as_raw_fd(),
            libc::O_RDONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC,
            st_mode as u32,
        )
        .map_err(|e| {
            self.error_for(mount_id, e).prefix(format!(
                "Failed to reopen mount point \"{mount_point}\" for reading"
            ))
        })?;
        // The O_PATH FD has served its purpose; close it before taking the write lock.
        drop(mount_point_file);

        let mut mount_fds_locked = self.map.write().unwrap();
        // Another thread may have created an entry while we were opening the mount point;
        // prefer it so that all callers share a single `MountFd` per mount.
        if let Some(mount_fd) = mount_fds_locked.get(&mount_id).and_then(Weak::upgrade) {
            return Ok(mount_fd);
        }
        debug!(
            "Creating MountFd: mount_id={}, mount_fd={}",
            mount_id,
            file.as_raw_fd(),
        );
        let mount_fd = Arc::new(MountFd {
            file,
            mount_id,
            map: Arc::downgrade(&self.map),
        });
        mount_fds_locked.insert(mount_id, Arc::downgrade(&mount_fd));
        Ok(mount_fd)
    }

    /// Checks that `mount_point_fd` is on the mount `mount_id` and returns its `st_mode`.
    ///
    /// This guards against the mount point path having resolved to a different mount than
    /// the one requested.
    ///
    /// # Errors
    ///
    /// Fails if `statx` fails, or with `EIO` if the mount IDs do not match.
    fn validate_mount_id(
        &self,
        mount_id: MountId,
        mount_point_fd: &impl AsRawFd,
        mount_point: &str,
    ) -> MPRResult<libc::mode_t> {
        let stx = statx(mount_point_fd, None).map_err(|e| {
            self.error_for(mount_id, e)
                .prefix(format!("Failed to stat mount point \"{mount_point}\""))
        })?;
        if stx.mnt_id != mount_id {
            return Err(self
                .error_for(mount_id, io::Error::from_raw_os_error(libc::EIO))
                .set_desc(format!(
                    "Mount point's ({}) mount ID ({}) does not match expected value ({})",
                    mount_point, stx.mnt_id, mount_id
                )));
        }
        Ok(stx.st.st_mode)
    }

    /// Returns the path at which mount `mount_id` is mounted, as listed in mountinfo.
    ///
    /// The mount point (fifth field of the matching line) has its `\ooo` octal escapes
    /// decoded. If `mount_prefix` is set, the path is then translated relative to that
    /// prefix (see `strip_mount_prefix`).
    ///
    /// # Errors
    ///
    /// Fails if mountinfo cannot be rewound or read, or contains no line for `mount_id`
    /// (reported with the error from `einval()`).
    fn get_mount_root(&self, mount_id: MountId) -> MPRResult<String> {
        let mountinfo = {
            // Hold the lock across rewind + read so concurrent lookups don't interleave on
            // the shared file offset.
            let mut mountinfo_file = self.mount_info.lock().unwrap();
            mountinfo_file.rewind().map_err(|e| {
                self.error_for_nolookup(mount_id, e)
                    .prefix("Failed to access /proc/self/mountinfo".into())
            })?;
            let mut mountinfo = String::new();
            mountinfo_file.read_to_string(&mut mountinfo).map_err(|e| {
                self.error_for_nolookup(mount_id, e)
                    .prefix("Failed to read /proc/self/mountinfo".into())
            })?;
            mountinfo
        };

        let escaped = mountinfo.split('\n').find_map(|line| {
            // Fields are separated by single spaces; whitespace inside paths is escaped, but
            // other Unicode whitespace is not, so don't split on `char::is_whitespace`.
            let mut columns = line.split(' ');
            if columns.next()?.parse::<MountId>().ok()? != mount_id {
                return None;
            }
            // Skip parent ID, major:minor and root; the next field is the mount point.
            columns.nth(3)
        });
        let Some(escaped) = escaped else {
            return Err(self
                .error_for_nolookup(mount_id, einval())
                .set_desc(format!("Failed to find mount root for mount ID {mount_id}")));
        };

        let mount_point = Self::unescape_mountinfo_field(escaped);
        Ok(match self.mount_prefix.as_deref() {
            Some(prefix) => Self::strip_mount_prefix(&mount_point, prefix).to_owned(),
            None => mount_point,
        })
    }

    /// Decodes the `\ooo` octal escapes the kernel uses for space, tab, newline and
    /// backslash in mountinfo path fields. Anything that is not a valid 3-digit ASCII
    /// escape is kept verbatim.
    fn unescape_mountinfo_field(field: &str) -> String {
        let mut unescaped = String::with_capacity(field.len());
        let mut rest = field;
        while let Some(pos) = rest.find('\\') {
            unescaped.push_str(&rest[..pos]);
            let after = &rest[pos + 1..];
            let byte = after
                .get(..3)
                .filter(|digits| digits.bytes().all(|b| (b'0'..=b'7').contains(&b)))
                .and_then(|digits| u8::from_str_radix(digits, 8).ok())
                .filter(u8::is_ascii);
            match byte {
                Some(byte) => {
                    unescaped.push(char::from(byte));
                    rest = &after[3..];
                }
                None => {
                    unescaped.push('\\');
                    rest = after;
                }
            }
        }
        unescaped.push_str(rest);
        unescaped
    }

    /// Translates `mount_point` into the path namespace rooted at `prefix`.
    ///
    /// A mount point strictly below `prefix` yields the remainder (which starts with `/`).
    /// The prefix itself, or any mount outside it, yields `/`. Matching respects path
    /// components, so prefix `/srv/share` does not match `/srv/shared`. Trailing slashes
    /// on `prefix` are ignored.
    fn strip_mount_prefix<'a>(mount_point: &'a str, prefix: &str) -> &'a str {
        let prefix = prefix.trim_end_matches('/');
        match mount_point.strip_prefix(prefix) {
            Some(suffix) if suffix.starts_with('/') => suffix,
            // Either the mount point *is* the prefix (empty suffix) or it lies outside it.
            // In both cases it must be the mount the root directory is on.
            _ => "/",
        }
    }

    /// Wraps `err` in an `MPRError` tagged with `mount_id`, without looking up the root.
    ///
    /// Only the first error produced for a given mount ID is returned un-silenced; all
    /// later ones (including from racing threads) are marked silent.
    fn error_for_nolookup<E: ToString + Into<io::Error>>(
        &self,
        mount_id: MountId,
        err: E,
    ) -> MPRError {
        let err = MPRError::from(err).set_mount_id(mount_id);
        // Cheap shared-lock check first; the read guard is released at the end of this
        // statement, before the write lock is taken below.
        let already_logged = self.error_logged.read().unwrap().contains(&mount_id);
        // `insert` returns false if another thread recorded this ID between the two lock
        // acquisitions, so exactly one caller gets the un-silenced error.
        if !already_logged && self.error_logged.write().unwrap().insert(mount_id) {
            err
        } else {
            err.silence()
        }
    }

    /// Like `error_for_nolookup`, but a non-silent error additionally gets the mount
    /// root attached, if it can be looked up.
    pub fn error_for<E: ToString + Into<io::Error>>(&self, mount_id: MountId, err: E) -> MPRError {
        let err = self.error_for_nolookup(mount_id, err);
        if err.silent() {
            return err;
        }
        match self.get_mount_root(mount_id) {
            Ok(mount_root) => err.set_mount_root(mount_root),
            Err(_) => err,
        }
    }
}
/// Error produced while resolving a mount's root FD.
///
/// Wraps an [`io::Error`] together with a human-readable description and, when known,
/// the mount ID and mount root the error relates to. The `silent` flag marks errors that
/// callers may choose not to log.
#[derive(Debug)]
pub struct MPRError {
    io: io::Error,
    description: String,
    silent: bool,
    fs_mount_id: Option<MountId>,
    fs_mount_root: Option<String>,
}

/// Result alias for mount-root resolution.
pub type MPRResult<T> = Result<T, MPRError>;

impl<E: ToString + Into<io::Error>> From<E> for MPRError {
    /// Wraps `err`, using its `Display` text as the initial description.
    fn from(err: E) -> Self {
        // Capture the text before `into()` consumes `err`.
        let description = err.to_string();
        Self {
            description,
            io: err.into(),
            silent: false,
            fs_mount_id: None,
            fs_mount_root: None,
        }
    }
}

impl MPRError {
    /// Replaces the description with `s`.
    #[must_use]
    pub fn set_desc(self, s: String) -> Self {
        Self {
            description: s,
            ..self
        }
    }

    /// Prepends `s` to the description, giving `"{s}: {old description}"`.
    #[must_use]
    pub fn prefix(self, s: String) -> Self {
        let description = format!("{s}: {}", self.description);
        Self {
            description,
            ..self
        }
    }

    /// Records the mount ID the error relates to.
    #[must_use]
    fn set_mount_id(self, mount_id: MountId) -> Self {
        Self {
            fs_mount_id: Some(mount_id),
            ..self
        }
    }

    /// Records the mount root the error relates to.
    #[must_use]
    fn set_mount_root(self, mount_root: String) -> Self {
        Self {
            fs_mount_root: Some(mount_root),
            ..self
        }
    }

    /// Marks the error as silent.
    #[must_use]
    fn silence(self) -> Self {
        Self {
            silent: true,
            ..self
        }
    }

    /// Whether the error has been marked silent.
    pub fn silent(&self) -> bool {
        self.silent
    }

    /// Discards the context and returns the wrapped I/O error.
    pub fn into_inner(self) -> io::Error {
        self.io
    }
}

impl std::fmt::Display for MPRError {
    /// Writes the description, preceded by whatever mount context is known.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(root) = &self.fs_mount_root {
            write!(f, "Filesystem mounted on \"{root}\"")?;
            if let Some(id) = self.fs_mount_id {
                write!(f, " (mount ID: {id})")?;
            }
            f.write_str(": ")?;
        } else if let Some(id) = self.fs_mount_id {
            write!(f, "Filesystem with mount ID {id}: ")?;
        }
        f.write_str(&self.description)
    }
}

impl std::error::Error for MPRError {}
#[cfg(test)]
mod tests {
use super::*;
use crate::passthrough::file_handle::FileHandle;
/// Checks that `MountFds::get` hands out one shared `MountFd` per mount ID, that the
/// table holds only a weak reference to it, and that the table entry is removed once the
/// last `Arc<MountFd>` is dropped.
#[test]
fn test_mount_fd_get() {
// Assumes the current directory contains `Cargo.toml` (e.g. the crate root under
// `cargo test`) — TODO confirm if tests are ever run from elsewhere.
let topdir = std::env::current_dir().unwrap();
let dir = File::open(&topdir).unwrap();
let filename = CString::new("Cargo.toml").unwrap();
let mount_fds = MountFds::new(None).unwrap();
// The file handle is only used for its mount ID.
let handle = FileHandle::from_name_at(&dir, &filename).unwrap().unwrap();
// The reopen callback ignores the O_PATH FD, flags and mode, and opens the working
// directory instead.
let fd1 = mount_fds
.get(handle.mnt_id, |_fd, _flags, _mode| File::open(&topdir))
.unwrap();
// The table stores a `Weak`, so the caller holds the only strong reference.
assert_eq!(Arc::strong_count(&fd1), 1);
assert_eq!(mount_fds.map.read().unwrap().len(), 1);
// A second lookup for the same mount must return the cached `MountFd`.
let fd2 = mount_fds
.get(handle.mnt_id, |_fd, _flags, _mode| File::open(&topdir))
.unwrap();
assert_eq!(Arc::strong_count(&fd2), 2);
assert_eq!(mount_fds.map.read().unwrap().len(), 1);
assert_eq!(fd1.as_fd().as_raw_fd(), fd2.as_fd().as_raw_fd());
// Dropping one of two references keeps the entry.
drop(fd1);
assert_eq!(Arc::strong_count(&fd2), 1);
assert_eq!(mount_fds.map.read().unwrap().len(), 1);
// Dropping the last reference runs `MountFd::drop`, which removes the entry.
drop(fd2);
assert_eq!(mount_fds.map.read().unwrap().len(), 0);
}
/// Checks the defaults of an `MPRError` built from an `io::Error`, and that `silence()`
/// sets the silent flag while `Display` still produces text.
#[test]
fn test_mpr_error() {
let io_error = io::Error::other("test");
let mpr_error = MPRError::from(io_error);
assert!(!mpr_error.silent);
assert!(mpr_error.fs_mount_id.is_none());
assert!(mpr_error.fs_mount_root.is_none());
let mpr_error = mpr_error.silence();
let msg = format!("{mpr_error}");
assert!(!msg.is_empty());
assert!(mpr_error.silent());
}
}