use std::fs::File;
use std::os::unix::fs::MetadataExt;
use std::os::unix::io::AsRawFd;
use std::path::{Path, PathBuf};
use std::thread::{self, JoinHandle};
use nix::mount::{mount, umount2, MntFlags, MsFlags};
use nix::sched::{setns, unshare, CloneFlags};
use nix::unistd::gettid;
use crate::{Error, Result};
pub trait Env {
fn persist_dir(&self) -> PathBuf;
fn contains<P: AsRef<Path>>(&self, p: P) -> bool {
p.as_ref().starts_with(self.persist_dir())
}
fn init(&self) -> Result<()> {
let persist_dir = self.persist_dir();
std::fs::create_dir_all(&persist_dir).map_err(Error::CreateNsDirError)?;
let mut made_netns_persist_dir_mount: bool = false;
while let Err(e) = mount(
Some(""),
&persist_dir,
Some("none"),
MsFlags::MS_SHARED | MsFlags::MS_REC,
Some(""),
) {
if e != nix::errno::Errno::EINVAL || made_netns_persist_dir_mount {
return Err(Error::MountError(
format!("--make-rshared {}", persist_dir.display()),
e,
));
}
mount(
Some(&persist_dir),
&persist_dir,
Some("none"),
MsFlags::MS_BIND | MsFlags::MS_REC,
Some(""),
)
.map_err(|e| {
Error::MountError(
format!(
"-rbind {} to {}",
persist_dir.display(),
persist_dir.display()
),
e,
)
})?;
made_netns_persist_dir_mount = true;
}
Ok(())
}
}
/// The environment used when none is supplied; its `Env` implementation
/// reports `/var/run/netns` as the persist directory.
#[derive(Copy, Clone, Default, Debug)]
pub struct DefaultEnv;
impl Env for DefaultEnv {
fn persist_dir(&self) -> PathBuf {
PathBuf::from("/var/run/netns")
}
}
/// A handle to a network namespace, backed by an open file for that
/// namespace.
#[derive(Debug)]
pub struct NetNs<E: Env = DefaultEnv> {
/// Open file for the namespace; this is what `enter` hands to `setns`.
file: File,
/// Path the file was opened from.
path: PathBuf,
/// Environment the handle was looked up in. `Some` for handles built by
/// `get_from_env`; `None` for handles built by `get_from_path` and
/// `get_from_current_thread`.
env: Option<E>,
}
impl<E: Env> std::fmt::Display for NetNs<E> {
    /// Renders the handle as `NetNS { fd, dev, ino, path }`.
    ///
    /// `dev` and `ino` are left out when the file's metadata cannot be read.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let fd = self.file.as_raw_fd();
        let shown_path = self.path.display();
        match self.file.metadata() {
            Ok(meta) => write!(
                f,
                "NetNS {{ fd: {}, dev: {}, ino: {}, path: {} }}",
                fd,
                meta.dev(),
                meta.ino(),
                shown_path
            ),
            Err(_) => write!(f, "NetNS {{ fd: {}, path: {} }}", fd, shown_path),
        }
    }
}
impl<E1: Env, E2: Env> PartialEq<NetNs<E1>> for NetNs<E2> {
    /// Two handles are equal when they refer to the same namespace file.
    ///
    /// The checks run in this order:
    /// 1. identical raw descriptor numbers are equal;
    /// 2. otherwise, if metadata is readable for both files, equality is
    ///    "same device and same inode";
    /// 3. otherwise the stored paths are compared.
    fn eq(&self, other: &NetNs<E1>) -> bool {
        if self.file.as_raw_fd() == other.file.as_raw_fd() {
            return true;
        }

        // (device, inode) of a file, or `None` when its metadata is unreadable.
        let identity = |file: &File| file.metadata().ok().map(|meta| (meta.dev(), meta.ino()));

        // `and_then` keeps the short-circuit: the second file is only
        // inspected when the first one yielded metadata.
        let same_inode = identity(&self.file)
            .and_then(|mine| identity(&other.file).map(|theirs| mine == theirs));

        match same_inode {
            Some(verdict) => verdict,
            None => self.path == other.path,
        }
    }
}
impl<E: Env> NetNs<E> {
pub fn new_with_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
env.init()?;
let ns_path = env.persist_dir().join(ns_name.as_ref());
let _ = File::create(&ns_path).map_err(Error::CreateNsError)?;
Self::persistent(&ns_path, true).inspect_err(|_e| {
std::fs::remove_file(&ns_path).ok();
})?;
Self::get_from_env(ns_name, env)
}
fn persistent<P: AsRef<Path>>(ns_path: &P, new_thread: bool) -> Result<()> {
if new_thread {
let ns_path_clone = ns_path.as_ref().to_path_buf();
let new_thread: JoinHandle<Result<()>> =
thread::spawn(move || Self::persistent(&ns_path_clone, false));
match new_thread.join() {
Ok(t) => t?,
Err(e) => {
return Err(Error::JoinThreadError(format!("{:?}", e)));
}
};
} else {
unshare(CloneFlags::CLONE_NEWNET).map_err(Error::UnshareError)?;
let src = get_current_thread_netns_path();
mount(
Some(src.as_path()),
ns_path.as_ref(),
Some("none"),
MsFlags::MS_BIND,
Some(""),
)
.map_err(|e| {
Error::MountError(
format!("rbind {} to {}", src.display(), ns_path.as_ref().display()),
e,
)
})?;
}
Ok(())
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn env(&self) -> Option<&E> {
self.env.as_ref()
}
pub fn file(&self) -> &File {
&self.file
}
pub fn enter(&self) -> Result<()> {
setns(&self.file, CloneFlags::CLONE_NEWNET).map_err(Error::SetnsError)
}
pub fn get_from_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
let ns_path = env.persist_dir().join(ns_name.as_ref());
let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
Ok(Self {
file,
path: ns_path,
env: Some(env),
})
}
pub fn remove(self) -> Result<()> {
drop(self.file);
if let Some(env) = &self.env {
if env.contains(&self.path) {
Self::umount_ns(&self.path)?;
}
}
Ok(())
}
fn umount_ns<P: AsRef<Path>>(path: P) -> Result<()> {
let path = path.as_ref();
umount2(path, MntFlags::MNT_DETACH).map_err(|e| Error::UnmountError(path.to_owned(), e))?;
std::fs::remove_file(path).ok();
Ok(())
}
pub fn run<F, T>(&self, f: F) -> Result<T>
where
F: FnOnce(&Self) -> T,
{
let src_ns = get_from_current_thread()?;
if &src_ns == self {
return Ok(f(self));
}
self.enter()?;
let result = f(self);
src_ns.enter()?;
Ok(result)
}
}
impl NetNs {
    /// Creates a new persistent namespace called `ns_name` using
    /// [`DefaultEnv`]. See `new_with_env` for the failure cases.
    pub fn new<S: AsRef<str>>(ns_name: S) -> Result<Self> {
        let env = DefaultEnv;
        Self::new_with_env(ns_name, env)
    }

    /// Opens the existing namespace called `ns_name` using [`DefaultEnv`].
    pub fn get<S: AsRef<str>>(ns_name: S) -> Result<Self> {
        let env = DefaultEnv;
        Self::get_from_env(ns_name, env)
    }

    /// Opens the namespace called `ns_name` using [`DefaultEnv`] and runs `f`
    /// inside it, returning whatever `f` returns.
    pub fn run_in<S, F, T>(ns_name: S, f: F) -> Result<T>
    where
        S: AsRef<str>,
        F: FnOnce(&Self) -> T,
    {
        Self::get_from_env(ns_name, DefaultEnv).and_then(|target| target.run(f))
    }
}
pub fn get_from_path<P: AsRef<Path>>(ns_path: P) -> Result<NetNs> {
let ns_path = ns_path.as_ref().to_path_buf();
let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
Ok(NetNs {
file,
path: ns_path,
env: None,
})
}
pub fn get_from_current_thread() -> Result<NetNs> {
let ns_path = get_current_thread_netns_path();
let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
Ok(NetNs {
file,
path: ns_path,
env: None,
})
}
/// Builds the `/proc` path of the calling thread's network namespace, using
/// the thread id reported by `gettid`.
#[inline]
fn get_current_thread_netns_path() -> PathBuf {
    let tid = gettid();
    let rendered = format!("/proc/self/task/{}/ns/net", tid);
    rendered.into()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::FromRawFd;

    /// Builds a handle around an arbitrary descriptor number so the
    /// "metadata unavailable" code paths can be exercised.
    fn make_dummy_netns(fd: i32, path: &str) -> ManuallyDrop<NetNs<DefaultEnv>> {
        ManuallyDrop::new(NetNs {
            // SAFETY: the tests only format and compare these handles, and the
            // `ManuallyDrop` wrapper keeps `File`'s destructor from closing a
            // descriptor this test does not own.
            file: unsafe { File::from_raw_fd(fd) },
            path: PathBuf::from(path),
            env: None,
        })
    }

    #[test]
    fn test_netns_display() {
        let ns = get_from_current_thread().unwrap();
        let print = format!("{}", ns);
        assert!(print.contains("dev"));
        assert!(print.contains("ino"));

        let ns = make_dummy_netns(i32::MAX, "");
        let print = format!("{}", *ns);
        assert!(!print.contains("dev"));
        assert!(!print.contains("ino"));
    }

    #[test]
    fn test_netns_eq() {
        let ns1 = get_from_current_thread().unwrap();
        let ns2 = get_from_path("/proc/self/ns/net").unwrap();
        assert_eq!(ns1, ns2);

        // Same descriptor number: equal regardless of path.
        let ns1 = make_dummy_netns(i32::MAX, "aaaaaa");
        let ns2 = make_dummy_netns(i32::MAX, "bbbbbb");
        assert_eq!(*ns1, *ns2);

        // Different descriptors without metadata: falls back to the path.
        let ns2 = make_dummy_netns(i32::MAX - 1, "aaaaaa");
        assert_eq!(*ns1, *ns2);
    }

    #[test]
    fn test_netns_init() {
        let ns = NetNs::new("test_netns_init").unwrap();
        assert!(ns.path().exists());
        ns.remove().unwrap();
        assert!(!Path::new(&DefaultEnv.persist_dir())
            .join("test_netns_init")
            .exists());
    }

    /// Creates a namespace in [`DefaultEnv`] and removes it again on drop.
    struct TestNetNs {
        netns: Option<NetNs>,
        ns_name: String,
    }

    impl TestNetNs {
        fn new(name: &str) -> Self {
            let netns = NetNs::new(name).unwrap();
            assert!(netns.path().exists());
            Self {
                netns: Some(netns),
                ns_name: String::from(name),
            }
        }

        fn netns(&self) -> &NetNs {
            self.netns.as_ref().unwrap()
        }
    }

    impl Drop for TestNetNs {
        fn drop(&mut self) {
            let ns_name = self.ns_name.clone();
            self.netns.take().unwrap().remove().unwrap();
            assert!(!Path::new(&DefaultEnv.persist_dir()).join(ns_name).exists());
        }
    }

    #[test]
    fn test_netns_enter() {
        let new = TestNetNs::new("test_netns_enter");
        let src = get_from_current_thread().unwrap();
        assert_ne!(&src, new.netns());

        new.netns().enter().unwrap();

        let cur = get_from_current_thread().unwrap();
        assert_eq!(new.netns(), &cur);
        assert_ne!(src, cur);
        assert_ne!(&src, new.netns());
    }

    /// Environment that keeps its namespaces under `/tmp/test_netns`.
    struct TestEnv;

    impl Env for TestEnv {
        fn persist_dir(&self) -> PathBuf {
            PathBuf::from("/tmp/test_netns")
        }
    }

    #[test]
    fn test_netns_with_env() {
        // One constant for every step, so the final check inspects the file
        // that was actually created and removed.
        const NS_NAME: &str = "test_netns_run";

        let ns_res = NetNs::get_from_env(NS_NAME, TestEnv);
        assert!(matches!(ns_res, Err(Error::OpenNsError(_, _))));

        let ns = NetNs::new_with_env(NS_NAME, TestEnv).unwrap();
        assert!(ns.path().exists());

        ns.remove().unwrap();
        assert!(!Path::new(&TestEnv.persist_dir()).join(NS_NAME).exists());
    }

    #[test]
    fn test_netns_run() {
        let new = TestNetNs::new("test_netns_run");
        let src_ns = get_from_current_thread().unwrap();

        let ret = new
            .netns()
            .run(|cur_ns| -> Result<()> {
                let cur_thread = get_from_current_thread().unwrap();
                assert_eq!(cur_ns, &cur_thread);
                assert_eq!(cur_ns, new.netns());
                assert_ne!(cur_ns, &src_ns);
                Ok(())
            })
            .unwrap();
        assert!(ret.is_ok());
    }
}