1use std::fs::File;
7use std::os::unix::fs::MetadataExt;
8use std::os::unix::io::AsRawFd;
9use std::path::{Path, PathBuf};
10use std::thread::{self, JoinHandle};
11
12use nix::mount::{mount, umount2, MntFlags, MsFlags};
13use nix::sched::{setns, unshare, CloneFlags};
14use nix::unistd::gettid;
15
16use crate::{Error, Result};
17
18pub trait Env {
20 fn persist_dir(&self) -> PathBuf;
22
23 fn contains<P: AsRef<Path>>(&self, p: P) -> bool {
25 p.as_ref().starts_with(self.persist_dir())
26 }
27
28 fn init(&self) -> Result<()> {
30 let persist_dir = self.persist_dir();
34 std::fs::create_dir_all(&persist_dir).map_err(Error::CreateNsDirError)?;
35
36 let mut made_netns_persist_dir_mount: bool = false;
40 while let Err(e) = mount(
41 Some(""),
42 &persist_dir,
43 Some("none"),
44 MsFlags::MS_SHARED | MsFlags::MS_REC,
45 Some(""),
46 ) {
47 if e != nix::errno::Errno::EINVAL || made_netns_persist_dir_mount {
49 return Err(Error::MountError(
50 format!("--make-rshared {}", persist_dir.display()),
51 e,
52 ));
53 }
54 mount(
57 Some(&persist_dir),
58 &persist_dir,
59 Some("none"),
60 MsFlags::MS_BIND | MsFlags::MS_REC,
61 Some(""),
62 )
63 .map_err(|e| {
64 Error::MountError(
65 format!(
66 "-rbind {} to {}",
67 persist_dir.display(),
68 persist_dir.display()
69 ),
70 e,
71 )
72 })?;
73 made_netns_persist_dir_mount = true;
74 }
75
76 Ok(())
77 }
78}
79
80#[derive(Copy, Clone, Default, Debug)]
85pub struct DefaultEnv;
86
87impl Env for DefaultEnv {
88 fn persist_dir(&self) -> PathBuf {
89 PathBuf::from("/var/run/netns")
90 }
91}
92
93#[derive(Debug)]
97pub struct NetNs<E: Env = DefaultEnv> {
98 file: File,
99 path: PathBuf,
100 env: Option<E>,
101}
102
103impl<E: Env> std::fmt::Display for NetNs<E> {
104 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
105 if let Ok(meta) = self.file.metadata() {
106 write!(
107 f,
108 "NetNS {{ fd: {}, dev: {}, ino: {}, path: {} }}",
109 self.file.as_raw_fd(),
110 meta.dev(),
111 meta.ino(),
112 self.path.display()
113 )
114 } else {
115 write!(
116 f,
117 "NetNS {{ fd: {}, path: {} }}",
118 self.file.as_raw_fd(),
119 self.path.display()
120 )
121 }
122 }
123}
124
125impl<E1: Env, E2: Env> PartialEq<NetNs<E1>> for NetNs<E2> {
126 fn eq(&self, other: &NetNs<E1>) -> bool {
127 if self.file.as_raw_fd() == other.file.as_raw_fd() {
128 return true;
129 }
130 let cmp_meta = |f1: &File, f2: &File| -> Option<bool> {
131 let m1 = match f1.metadata() {
132 Ok(m) => m,
133 Err(_) => return None,
134 };
135 let m2 = match f2.metadata() {
136 Ok(m) => m,
137 Err(_) => return None,
138 };
139 Some(m1.dev() == m2.dev() && m1.ino() == m2.ino())
140 };
141 cmp_meta(&self.file, &other.file).unwrap_or_else(|| self.path == other.path)
142 }
143}
144
145impl<E: Env> NetNs<E> {
146 pub fn new_with_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
150 env.init()?;
151
152 let ns_path = env.persist_dir().join(ns_name.as_ref());
154 let _ = File::create(&ns_path).map_err(Error::CreateNsError)?;
155 Self::persistent(&ns_path, true).inspect_err(|_e| {
156 std::fs::remove_file(&ns_path).ok();
159 })?;
160 Self::get_from_env(ns_name, env)
161 }
162
163 fn persistent<P: AsRef<Path>>(ns_path: &P, new_thread: bool) -> Result<()> {
164 if new_thread {
165 let ns_path_clone = ns_path.as_ref().to_path_buf();
166 let new_thread: JoinHandle<Result<()>> =
167 thread::spawn(move || Self::persistent(&ns_path_clone, false));
168 match new_thread.join() {
169 Ok(t) => t?,
170 Err(e) => {
171 return Err(Error::JoinThreadError(format!("{:?}", e)));
172 }
173 };
174 } else {
175 unshare(CloneFlags::CLONE_NEWNET).map_err(Error::UnshareError)?;
177 let src = get_current_thread_netns_path();
180 mount(
181 Some(src.as_path()),
182 ns_path.as_ref(),
183 Some("none"),
184 MsFlags::MS_BIND,
185 Some(""),
186 )
187 .map_err(|e| {
188 Error::MountError(
189 format!("rbind {} to {}", src.display(), ns_path.as_ref().display()),
190 e,
191 )
192 })?;
193 }
194
195 Ok(())
196 }
197
198 pub fn path(&self) -> &Path {
200 &self.path
201 }
202
203 pub fn env(&self) -> Option<&E> {
205 self.env.as_ref()
206 }
207
208 pub fn file(&self) -> &File {
210 &self.file
211 }
212
213 pub fn enter(&self) -> Result<()> {
217 setns(&self.file, CloneFlags::CLONE_NEWNET).map_err(Error::SetnsError)
218 }
219
220 pub fn get_from_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
222 let ns_path = env.persist_dir().join(ns_name.as_ref());
223 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
224
225 Ok(Self {
226 file,
227 path: ns_path,
228 env: Some(env),
229 })
230 }
231
232 pub fn remove(self) -> Result<()> {
236 drop(self.file);
238 if let Some(env) = &self.env {
240 if env.contains(&self.path) {
241 Self::umount_ns(&self.path)?;
242 }
243 }
244 Ok(())
245 }
246
247 fn umount_ns<P: AsRef<Path>>(path: P) -> Result<()> {
248 let path = path.as_ref();
249 umount2(path, MntFlags::MNT_DETACH).map_err(|e| Error::UnmountError(path.to_owned(), e))?;
250 std::fs::remove_file(path).ok();
252 Ok(())
253 }
254
255 pub fn run<F, T>(&self, f: F) -> Result<T>
259 where
260 F: FnOnce(&Self) -> T,
261 {
262 let src_ns = get_from_current_thread()?;
264
265 if &src_ns == self {
267 return Ok(f(self));
268 }
269 self.enter()?;
271
272 let result = f(self);
273 src_ns.enter()?;
275
276 Ok(result)
277 }
278}
279
280impl NetNs {
281 pub fn new<S: AsRef<str>>(ns_name: S) -> Result<Self> {
292 Self::new_with_env(ns_name, DefaultEnv)
293 }
294
295 pub fn get<S: AsRef<str>>(ns_name: S) -> Result<Self> {
297 Self::get_from_env(ns_name, DefaultEnv)
298 }
299
300 pub fn run_in<S, F, T>(ns_name: S, f: F) -> Result<T>
304 where
305 S: AsRef<str>,
306 F: FnOnce(&Self) -> T,
307 {
308 let run_ns = Self::get_from_env(ns_name, DefaultEnv)?;
310 run_ns.run(f)
311 }
312}
313
314pub fn get_from_path<P: AsRef<Path>>(ns_path: P) -> Result<NetNs> {
316 let ns_path = ns_path.as_ref().to_path_buf();
317 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
318
319 Ok(NetNs {
320 file,
321 path: ns_path,
322 env: None,
323 })
324}
325
326pub fn get_from_current_thread() -> Result<NetNs> {
328 let ns_path = get_current_thread_netns_path();
329 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
330
331 Ok(NetNs {
332 file,
333 path: ns_path,
334 env: None,
335 })
336}
337
338#[inline]
339fn get_current_thread_netns_path() -> PathBuf {
340 PathBuf::from(format!("/proc/self/task/{}/ns/net", gettid()))
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::FromRawFd;

    /// Wraps an arbitrary raw fd in a `NetNs` that is never dropped, so the fd
    /// is never closed.
    ///
    /// Used to exercise the metadata-failure fallbacks of `Display` and `PartialEq`.
    fn make_dummy_netns(fd: i32, path: &str) -> ManuallyDrop<NetNs<DefaultEnv>> {
        ManuallyDrop::new(NetNs {
            // SAFETY: `from_raw_fd` formally requires an open, owned fd, which
            // the dummy values used in these tests are not. This is tolerated
            // only because the `File` is never dropped (so never closed) and is
            // used solely for `as_raw_fd`/`metadata`, where an invalid fd
            // yields an error rather than undefined behaviour.
            file: unsafe { File::from_raw_fd(fd) },
            path: PathBuf::from(path),
            env: None,
        })
    }

    #[test]
    fn test_netns_display() {
        let ns = get_from_current_thread().unwrap();
        let print = format!("{}", ns);
        assert!(print.contains("dev"));
        assert!(print.contains("ino"));

        let ns = make_dummy_netns(i32::MAX, "");
        let print = format!("{}", *ns);
        assert!(!print.contains("dev"));
        assert!(!print.contains("ino"));
    }

    #[test]
    fn test_netns_eq() {
        let ns1 = get_from_current_thread().unwrap();
        let ns2 = get_from_path("/proc/self/ns/net").unwrap();
        assert_eq!(ns1, ns2);

        // Same raw fd: equal regardless of path.
        let ns1 = make_dummy_netns(i32::MAX, "aaaaaa");
        let ns2 = make_dummy_netns(i32::MAX, "bbbbbb");
        assert_eq!(*ns1, *ns2);

        // Unreadable metadata: falls back to comparing paths.
        let ns2 = make_dummy_netns(i32::MAX - 1, "aaaaaa");
        assert_eq!(*ns1, *ns2);
    }

    #[test]
    fn test_netns_init() {
        let ns = NetNs::new("test_netns_init").unwrap();
        assert!(ns.path().exists());
        ns.remove().unwrap();
        assert!(!Path::new(&DefaultEnv.persist_dir())
            .join("test_netns_init")
            .exists());
    }

    /// Fixture that creates a named namespace in `DefaultEnv`.
    ///
    /// On drop it removes the namespace and asserts that the file is gone.
    struct TestNetNs {
        netns: Option<NetNs>,
        ns_name: String,
    }

    impl TestNetNs {
        fn new(name: &str) -> Self {
            let netns = NetNs::new(name).unwrap();
            assert!(netns.path().exists());
            Self {
                netns: Some(netns),
                ns_name: String::from(name),
            }
        }

        fn netns(&self) -> &NetNs {
            self.netns.as_ref().unwrap()
        }
    }

    impl Drop for TestNetNs {
        fn drop(&mut self) {
            let ns_name = self.ns_name.clone();
            self.netns.take().unwrap().remove().unwrap();
            assert!(!Path::new(&DefaultEnv.persist_dir()).join(ns_name).exists());
        }
    }

    #[test]
    fn test_netns_enter() {
        let new = TestNetNs::new("test_netns_enter");

        let src = get_from_current_thread().unwrap();
        assert_ne!(&src, new.netns());

        new.netns().enter().unwrap();

        let cur = get_from_current_thread().unwrap();

        assert_eq!(new.netns(), &cur);
        assert_ne!(src, cur);
        assert_ne!(&src, new.netns());

        // Put the test thread back where it started.
        src.enter().unwrap();
    }

    struct TestEnv;
    impl Env for TestEnv {
        fn persist_dir(&self) -> PathBuf {
            PathBuf::from("/tmp/test_netns")
        }
    }

    #[test]
    fn test_netns_with_env() {
        // One name for create, lookup, and the post-removal check.
        const NAME: &str = "test_netns_with_env";

        let ns_res = NetNs::get_from_env(NAME, TestEnv);
        assert!(matches!(ns_res, Err(Error::OpenNsError(_, _))));

        let ns = NetNs::new_with_env(NAME, TestEnv).unwrap();
        assert!(ns.path().exists());

        ns.remove().unwrap();
        assert!(!Path::new(&TestEnv.persist_dir()).join(NAME).exists());
    }

    #[test]
    fn test_netns_run() {
        let new = TestNetNs::new("test_netns_run");

        let src_ns = get_from_current_thread().unwrap();

        let ret = new
            .netns()
            .run(|cur_ns| -> Result<()> {
                let cur_thread = get_from_current_thread().unwrap();
                assert_eq!(cur_ns, &cur_thread);
                assert_eq!(cur_ns, new.netns());
                assert_ne!(cur_ns, &src_ns);

                Ok(())
            })
            .unwrap();
        assert!(ret.is_ok());
    }
}