use std::io;
use std::os::fd::{AsFd, AsRawFd};
use std::path::{Path, PathBuf};
use tokio::sync::oneshot;
use crate::dynch::{DynCh, DynRequestSender};
use crate::namespace::helpers::current_netns;
use crate::network::RuntimeFactory;
/// Directory that holds the named network namespace files; `NetworkNamespace::path`
/// joins a namespace name onto it.
const NETNS_RUN_DIR: &str = "/run/netns";
/// procfs directory with one entry per open file descriptor of this process;
/// used to resolve a descriptor to the namespace file it refers to.
const PROC_SELF_FD: &str = "/proc/self/fd";
/// procfs directory with one entry per thread of this process; used to look up
/// the namespace of a single thread rather than of the whole process.
const PROC_SELF_TASK: &str = "/proc/self/task";
pub(crate) mod helpers {
use std::{
fs, io,
os::fd::{AsRawFd as _, BorrowedFd},
};
use nix::mount::{MsFlags, mount};
use nix::sched::CloneFlags;
/// Identity of a network namespace, expressed as the inode number of the
/// file that represents it.
///
/// Two values compare equal exactly when their inode numbers are equal.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct NetnsId {
    /// Inode number reported for the namespace file.
    pub inode: u64,
}

impl NetnsId {
    /// Builds an identifier from metadata that has already been fetched.
    fn from_metadata(meta: &fs::Metadata) -> Self {
        // Call the extension trait by path instead of importing it locally.
        let inode = std::os::unix::fs::MetadataExt::ino(meta);
        Self { inode }
    }
}
/// Returns the network namespace the *calling thread* is currently in.
///
/// The lookup goes through the thread's own entry under `/proc/self/task`,
/// so the answer is per-thread rather than per-process.
///
/// # Errors
///
/// Returns the I/O error from reading the metadata of the thread's
/// `ns/net` entry.
pub fn current_netns() -> io::Result<NetnsId> {
    // SAFETY: `gettid` takes no arguments and does not access caller memory.
    let tid = unsafe { nix::libc::gettid() };
    let path = format!("{}/{}/ns/net", super::PROC_SELF_TASK, tid);
    let ns = fs::metadata(&path).map(|meta| NetnsId::from_metadata(&meta))?;
    tracing::debug!(ino = ns.inode, ?path, "current netns");
    Ok(ns)
}
/// Returns the identity of the namespace that the open descriptor `fd` refers to.
///
/// The descriptor is resolved through its `/proc/self/fd/<n>` entry and the
/// inode of the file found there becomes the identifier.
///
/// # Errors
///
/// Returns the I/O error from reading that entry's metadata.
pub fn netns_from_fd(fd: BorrowedFd<'_>) -> io::Result<NetnsId> {
    let raw = fd.as_raw_fd();
    let path = format!("{}/{}", super::PROC_SELF_FD, raw);
    let ns = NetnsId::from_metadata(&fs::metadata(&path)?);
    tracing::debug!(ino = ns.inode, fd = raw, path = %path, "netns from fd");
    Ok(ns)
}
/// Moves the calling thread into the network namespace referred to by `fd`
/// and checks that the switch really took effect.
///
/// Returns `(before, after)`: the namespace the thread was in before the call
/// and the one it is in afterwards. The two are equal when the thread was
/// already in the target namespace.
///
/// # Errors
///
/// Fails if either namespace lookup fails, if `setns` itself fails, or if the
/// namespace observed after the call is not the one `fd` refers to.
pub fn setns_verified(fd: BorrowedFd<'_>) -> io::Result<(NetnsId, NetnsId)> {
    let _span = tracing::debug_span!("setns_verified", fd = fd.as_raw_fd()).entered();

    let origin = current_netns()?;
    let wanted = netns_from_fd(fd)?;
    tracing::debug!(before_ino = origin.inode, target_ino = wanted.inode, "setns target");

    nix::sched::setns(fd, CloneFlags::CLONE_NEWNET)?;

    let reached = current_netns()?;
    tracing::debug!(after_ino = reached.inode, "after setns");

    // Compare against what the descriptor resolves to rather than trusting
    // the syscall's return value alone.
    if reached != wanted {
        let message = format!(
            "setns verification failed: after={:?}, target={:?}",
            reached, wanted
        );
        return Err(io::Error::other(message));
    }

    if origin != reached {
        tracing::debug!("namespace switch successful");
    } else {
        tracing::debug!("setns was a no-op (already in target namespace)");
    }

    Ok((origin, reached))
}
/// Gives the calling thread its own mount namespace with a fresh `/proc`.
///
/// Steps, in order:
/// 1. `unshare(CLONE_NEWNS)` to detach from the caller's mount namespace.
/// 2. Recursively mark the whole mount tree as `MS_SLAVE`. The mounts copied
///    into the new namespace keep their propagation type, so on hosts where
///    `/` is a shared mount the `/proc` mount below would otherwise propagate
///    back into the parent mount namespace.
/// 3. Mount a new `proc` instance on `/proc`.
///
/// # Errors
///
/// Returns an `io::Error` describing which of the three steps failed.
pub fn setup_mount_namespace() -> io::Result<()> {
    nix::sched::unshare(CloneFlags::CLONE_NEWNS)
        .map_err(|e| io::Error::other(format!("unshare(CLONE_NEWNS) failed: {}", e)))?;
    tracing::debug!("created new mount namespace");

    // Stop mount events from this namespace reaching the parent namespace;
    // mounts made in the parent still propagate down to us.
    mount(
        None::<&str>,
        "/",
        None::<&str>,
        MsFlags::MS_SLAVE | MsFlags::MS_REC,
        None::<&str>,
    )
    .map_err(|e| io::Error::other(format!("making / a recursive slave mount failed: {}", e)))?;
    tracing::debug!("disabled mount propagation to the parent namespace");

    mount(Some("proc"), "/proc", Some("proc"), MsFlags::empty(), None::<&str>)
        .map_err(|e| io::Error::other(format!("mount proc failed: {}", e)))?;
    tracing::debug!("remounted /proc for namespace-specific view");
    Ok(())
}
/// Debug aid that pins the expectation "this thread stays in one network
/// namespace" to a scope.
///
/// Construction fails if the thread is not in the expected namespace. On drop,
/// debug builds assert that the thread is still in it.
pub struct NetnsGuard {
    /// Namespace the owning thread must be in for the guard's whole lifetime.
    expected: NetnsId,
}

impl NetnsGuard {
    /// Creates a guard for `expected`.
    ///
    /// # Errors
    ///
    /// Fails if the current namespace cannot be determined or differs from
    /// `expected`.
    pub fn new(expected: NetnsId) -> io::Result<Self> {
        if current_netns()? != expected {
            return Err(io::Error::other("thread not in expected netns"));
        }
        Ok(Self { expected })
    }
}

impl Drop for NetnsGuard {
    fn drop(&mut self) {
        // A failed assertion while the thread is already unwinding would be a
        // second panic and abort the process, hiding the original failure.
        if std::thread::panicking() {
            return;
        }
        // A failed lookup is ignored on purpose: the guard is diagnostic only.
        if let Ok(current) = current_netns() {
            debug_assert_eq!(
                current, self.expected,
                "network namespace changed while guard was alive"
            );
        }
    }
}
}
/// Errors produced while creating, entering or talking to a network namespace.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An underlying I/O operation failed.
#[error("io error: {0}")]
Io(#[from] io::Error),
/// An `rtnetlink` call failed.
#[error("rtnetlink error: {0}")]
RtNetlink(#[from] rtnetlink::Error),
/// A `nix` call failed.
#[error("nix error: {0}")]
Nix(#[from] nix::Error),
/// A thread-related failure, described by the contained message.
#[error("thread error: {0}")]
Thread(String),
/// A task could not be sent because the channel was closed.
#[error("failed to send task, channel closed")]
ChannelClosed,
/// The oneshot channel carrying a task result was closed before a value arrived.
#[error("failed to receive task result: {0}")]
RecvError(#[from] oneshot::error::RecvError),
}
/// Error type for a task, generic over the task's own error `E`.
#[derive(Debug, thiserror::Error)]
pub enum TaskError<E> {
/// An underlying I/O operation failed.
#[error("io error: {0}")]
Io(#[from] io::Error),
/// A `nix` call failed.
#[error("nix error: {0}")]
Nix(#[from] nix::Error),
/// The task itself reported an error of its own type.
#[error("task error: {0}")]
Task(E),
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// A named network namespace together with an open handle to its file.
#[derive(Debug)]
pub struct NetworkNamespaceInner {
    /// Name of the namespace.
    pub name: String,
    /// Open handle to the namespace file.
    pub file: std::fs::File,
}

/// Fallible counterpart of `Clone` for types that own duplicable resources.
trait TryClone: Sized {
    /// Error returned when duplication fails.
    type Error;

    /// Attempts to produce an independent copy of `self`.
    fn try_clone(&self) -> std::result::Result<Self, Self::Error>;
}

impl TryClone for NetworkNamespaceInner {
    type Error = io::Error;

    /// Duplicates the file handle first, then copies the name.
    fn try_clone(&self) -> std::result::Result<Self, Self::Error> {
        self.file.try_clone().map(|file| Self {
            name: self.name.clone(),
            file,
        })
    }
}
impl NetworkNamespaceInner {
pub fn spawn<Ctx: 'static>(
self,
runtime_factory: RuntimeFactory,
make_ctx: impl FnOnce() -> Ctx + Send + 'static,
) -> (std::thread::JoinHandle<Result<()>>, DynRequestSender<Ctx>) {
let (tx, mut rx) = DynCh::<Ctx>::channel(8);
let handle = std::thread::spawn(move || {
let fd = self.file.as_fd();
let name = &self.name;
let _span = tracing::info_span!("spawn_namespace", ?name).entered();
let (_before, after) = helpers::setns_verified(fd)?;
let _guard = helpers::NetnsGuard::new(after)?;
helpers::setup_mount_namespace()?;
let rt = runtime_factory();
tracing::debug!("started runtime");
drop(_span);
rt.block_on(async move {
let mut ctx = make_ctx();
while let Some(req) = rx.recv().await {
let (task, tx) = req.into_parts();
let _span = tracing::info_span!("namespace_job", ?fd, name).entered();
debug_assert_eq!(after, current_netns().expect("to check current namespace"));
let res = task(&mut ctx).await;
if tx.send(res).is_err() {
tracing::error!("failed to send back task response, rx dropped");
}
}
});
Ok(())
});
(handle, tx)
}
}
/// Handle to a named network namespace and the worker thread running inside it.
///
/// `Ctx` is the type of the context value that submitted tasks receive.
/// Dropping the handle deletes the namespace (see the `Drop` impl).
#[derive(Debug)]
pub struct NetworkNamespace<Ctx = ()> {
/// Namespace name and open handle to the namespace file.
pub inner: NetworkNamespaceInner,
/// Sender used to submit tasks to the worker thread.
pub task_sender: DynRequestSender<Ctx>,
/// Join handle of the worker thread; its result carries any setup error.
pub _receiver_task: std::thread::JoinHandle<Result<()>>,
}
impl NetworkNamespace {
pub async fn new<Ctx: 'static>(
name: impl Into<String>,
runtime_factory: RuntimeFactory,
make_ctx: impl FnOnce() -> Ctx + Send + 'static,
) -> Result<NetworkNamespace<Ctx>> {
let name = name.into();
rtnetlink::NetworkNamespace::add(name.clone()).await?;
let path = Self::path(&name);
let file = tokio::fs::File::open(path).await?.into_std().await;
let inner = NetworkNamespaceInner { name, file };
let (_receiver_task, task_sender) = inner.try_clone()?.spawn(runtime_factory, make_ctx);
Ok(NetworkNamespace::<Ctx> { inner, task_sender, _receiver_task })
}
}
impl<Ctx> NetworkNamespace<Ctx> {
    /// Returns the filesystem path of the namespace file for `name`, i.e.
    /// `name` joined onto the namespace run directory.
    pub fn path(name: &str) -> PathBuf {
        let base: &Path = Path::new(NETNS_RUN_DIR);
        let mut full = base.to_path_buf();
        full.push(name);
        full
    }

    /// Returns the raw descriptor number of the open namespace file.
    ///
    /// The descriptor stays owned by `self`; the caller must not close it.
    pub fn fd(&self) -> i32 {
        let borrowed = self.inner.file.as_fd();
        borrowed.as_raw_fd()
    }
}
impl<Ctx> Drop for NetworkNamespace<Ctx> {
fn drop(&mut self) {
let namespace = self.inner.name.clone();
let task = async {
if let Err(e) = rtnetlink::NetworkNamespace::del(namespace.clone()).await {
tracing::error!(?e, ?namespace, "failed to delete network namespace");
}
};
if tokio::runtime::Handle::try_current().is_ok() {
tokio::task::block_in_place(|| {
let handle = tokio::runtime::Handle::current();
handle.block_on(task)
});
return;
}
let rt =
tokio::runtime::Runtime::new().expect("failed to build temporary runtime for cleanup");
rt.block_on(task)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dynch::DynFuture;
use crate::network::default_runtime_factory;
/// procfs path of the sysctl that the isolation test writes and reads back.
const TCP_SLOW_START_AFTER_IDLE: &str = "/proc/sys/net/ipv4/tcp_slow_start_after_idle";
/// Creates two namespaces and checks, from inside each worker thread, that
/// `/proc/self/ns/net` exists, i.e. that a usable `/proc` is mounted there.
// NOTE(review): creating network namespaces presumably requires elevated
// privileges, so this test may need to run as root — confirm.
#[tokio::test(flavor = "multi_thread")]
async fn mount_namespace_isolates_proc() {
let ns1 = NetworkNamespace::new("test-ns-mount-1", default_runtime_factory(), || ())
.await
.unwrap();
let ns2 = NetworkNamespace::new("test-ns-mount-2", default_runtime_factory(), || ())
.await
.unwrap();
// The check runs on the namespace's worker thread, not on the test thread.
let proc_mounted_ns1: bool = ns1
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, bool> {
Box::pin(async { std::path::Path::new("/proc/self/ns/net").exists() })
})
.await
.unwrap()
.receive()
.await
.unwrap();
assert!(proc_mounted_ns1, "/proc should be mounted in namespace 1");
let proc_mounted_ns2: bool = ns2
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, bool> {
Box::pin(async { std::path::Path::new("/proc/self/ns/net").exists() })
})
.await
.unwrap()
.receive()
.await
.unwrap();
assert!(proc_mounted_ns2, "/proc should be mounted in namespace 2");
}
/// Writes different values to the same sysctl from two namespaces and checks
/// that each namespace reads back its own value.
// NOTE(review): creating network namespaces and writing sysctls presumably
// requires elevated privileges — confirm.
#[tokio::test(flavor = "multi_thread")]
async fn sysctl_values_are_namespace_specific() {
let ns1 = NetworkNamespace::new("test-ns-sysctl-1", default_runtime_factory(), || ())
.await
.unwrap();
let ns2 = NetworkNamespace::new("test-ns-sysctl-2", default_runtime_factory(), || ())
.await
.unwrap();
// Write "0" from inside namespace 1.
let write_result_ns1: std::io::Result<()> = ns1
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, std::io::Result<()>> {
Box::pin(async { std::fs::write(TCP_SLOW_START_AFTER_IDLE, "0") })
})
.await
.unwrap()
.receive()
.await
.unwrap();
write_result_ns1.expect("should write sysctl in ns1");
// Write "1" from inside namespace 2.
let write_result_ns2: std::io::Result<()> = ns2
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, std::io::Result<()>> {
Box::pin(async { std::fs::write(TCP_SLOW_START_AFTER_IDLE, "1") })
})
.await
.unwrap()
.receive()
.await
.unwrap();
write_result_ns2.expect("should write sysctl in ns2");
// Read both values back only after both writes, so a shared sysctl would
// show the second write in both namespaces. A read failure is mapped to the
// string "error", which then fails the equality assertions below.
let value_ns1: String = ns1
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, String> {
Box::pin(async {
std::fs::read_to_string(TCP_SLOW_START_AFTER_IDLE)
.map(|s| s.trim().to_string())
.unwrap_or_else(|_| "error".to_string())
})
})
.await
.unwrap()
.receive()
.await
.unwrap();
let value_ns2: String = ns2
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, String> {
Box::pin(async {
std::fs::read_to_string(TCP_SLOW_START_AFTER_IDLE)
.map(|s| s.trim().to_string())
.unwrap_or_else(|_| "error".to_string())
})
})
.await
.unwrap()
.receive()
.await
.unwrap();
assert_eq!(value_ns1, "0", "ns1 should have tcp_slow_start_after_idle=0");
assert_eq!(value_ns2, "1", "ns2 should have tcp_slow_start_after_idle=1");
assert_ne!(value_ns1, value_ns2, "sysctls should be isolated between namespaces");
}
/// Checks that the namespace inode seen from the worker thread differs from
/// the one seen from the test thread.
// NOTE(review): creating network namespaces presumably requires elevated
// privileges — confirm.
#[tokio::test(flavor = "multi_thread")]
async fn namespace_has_isolated_network_identity() {
let ns = NetworkNamespace::new("test-ns-identity", default_runtime_factory(), || ())
.await
.unwrap();
// A lookup failure is mapped to inode 0, which the assertions below reject.
let ns_inode_inside: u64 = ns
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, u64> {
Box::pin(async { helpers::current_netns().map(|id| id.inode).unwrap_or(0) })
})
.await
.unwrap()
.receive()
.await
.unwrap();
// Looked up on the thread running this test, which never entered the namespace.
let host_inode = helpers::current_netns().map(|id| id.inode).unwrap_or(0);
assert_ne!(ns_inode_inside, 0, "should get valid inode inside namespace");
assert_ne!(host_inode, 0, "should get valid host inode");
assert_ne!(ns_inode_inside, host_inode, "namespace inode should differ from host");
}
}