1use std::fs::File;
7use std::os::unix::fs::MetadataExt;
8use std::os::unix::io::{AsRawFd, IntoRawFd};
9use std::path::{Path, PathBuf};
10use std::thread::{self, JoinHandle};
11
12use nix::mount::{mount, umount2, MntFlags, MsFlags};
13use nix::sched::{setns, unshare, CloneFlags};
14use nix::unistd::gettid;
15
16use crate::{Error, Result};
17
18pub trait Env {
20 fn persist_dir(&self) -> PathBuf;
22
23 fn contains<P: AsRef<Path>>(&self, p: P) -> bool {
25 p.as_ref().starts_with(self.persist_dir())
26 }
27
28 fn init(&self) -> Result<()> {
30 let persist_dir = self.persist_dir();
34 std::fs::create_dir_all(&persist_dir).map_err(Error::CreateNsDirError)?;
35
36 let mut made_netns_persist_dir_mount: bool = false;
40 while let Err(e) = mount(
41 Some(""),
42 &persist_dir,
43 Some("none"),
44 MsFlags::MS_SHARED | MsFlags::MS_REC,
45 Some(""),
46 ) {
47 if e != nix::errno::Errno::EINVAL || made_netns_persist_dir_mount {
49 return Err(Error::MountError(
50 format!("--make-rshared {}", persist_dir.display()),
51 e,
52 ));
53 }
54 mount(
57 Some(&persist_dir),
58 &persist_dir,
59 Some("none"),
60 MsFlags::MS_BIND | MsFlags::MS_REC,
61 Some(""),
62 )
63 .map_err(|e| {
64 Error::MountError(
65 format!(
66 "-rbind {} to {}",
67 persist_dir.display(),
68 persist_dir.display()
69 ),
70 e,
71 )
72 })?;
73 made_netns_persist_dir_mount = true;
74 }
75
76 Ok(())
77 }
78}
79
80#[derive(Copy, Clone, Default, Debug)]
85pub struct DefaultEnv;
86
87impl Env for DefaultEnv {
88 fn persist_dir(&self) -> PathBuf {
89 PathBuf::from("/var/run/netns")
90 }
91}
92
/// Handle to a network namespace, backed by an open file that refers to it.
///
/// `E` is the [`Env`] the handle was created in or looked up from. The `env` field
/// is `None` for handles built by `get_from_path` / `get_from_current_thread`.
#[derive(Debug)]
pub struct NetNs<E: Env = DefaultEnv> {
    // Open file referring to the namespace; its descriptor is what `setns` receives.
    file: File,
    // Filesystem path the file was opened from.
    path: PathBuf,
    // Environment the handle belongs to, if any; `remove` only unmounts when set.
    env: Option<E>,
}
102
103impl<E: Env> std::fmt::Display for NetNs<E> {
104 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
105 if let Ok(meta) = self.file.metadata() {
106 write!(
107 f,
108 "NetNS {{ fd: {}, dev: {}, ino: {}, path: {} }}",
109 self.file.as_raw_fd(),
110 meta.dev(),
111 meta.ino(),
112 self.path.display()
113 )
114 } else {
115 write!(
116 f,
117 "NetNS {{ fd: {}, path: {} }}",
118 self.file.as_raw_fd(),
119 self.path.display()
120 )
121 }
122 }
123}
124
125impl<E1: Env, E2: Env> PartialEq<NetNs<E1>> for NetNs<E2> {
126 fn eq(&self, other: &NetNs<E1>) -> bool {
127 if self.file.as_raw_fd() == other.file.as_raw_fd() {
128 return true;
129 }
130 let cmp_meta = |f1: &File, f2: &File| -> Option<bool> {
131 let m1 = match f1.metadata() {
132 Ok(m) => m,
133 Err(_) => return None,
134 };
135 let m2 = match f2.metadata() {
136 Ok(m) => m,
137 Err(_) => return None,
138 };
139 Some(m1.dev() == m2.dev() && m1.ino() == m2.ino())
140 };
141 cmp_meta(&self.file, &other.file).unwrap_or_else(|| self.path == other.path)
142 }
143}
144
145impl<E: Env> NetNs<E> {
146 pub fn new_with_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
150 env.init()?;
151
152 let ns_path = env.persist_dir().join(ns_name.as_ref());
154 let _ = File::create(&ns_path).map_err(Error::CreateNsError)?;
155 Self::persistent(&ns_path, true).map_err(|e| {
156 std::fs::remove_file(&ns_path).ok();
159 e
160 })?;
161 Self::get_from_env(ns_name, env)
162 }
163
164 fn persistent<P: AsRef<Path>>(ns_path: &P, new_thread: bool) -> Result<()> {
165 if new_thread {
166 let ns_path_clone = ns_path.as_ref().to_path_buf();
167 let new_thread: JoinHandle<Result<()>> =
168 thread::spawn(move || Self::persistent(&ns_path_clone, false));
169 match new_thread.join() {
170 Ok(t) => {
171 if let Err(e) = t {
172 return Err(e);
173 }
174 }
175 Err(e) => {
176 return Err(Error::JoinThreadError(format!("{:?}", e)));
177 }
178 };
179 } else {
180 unshare(CloneFlags::CLONE_NEWNET).map_err(Error::UnshareError)?;
182 let src = get_current_thread_netns_path();
185 mount(
186 Some(src.as_path()),
187 ns_path.as_ref(),
188 Some("none"),
189 MsFlags::MS_BIND,
190 Some(""),
191 )
192 .map_err(|e| {
193 Error::MountError(
194 format!("rbind {} to {}", src.display(), ns_path.as_ref().display()),
195 e,
196 )
197 })?;
198 }
199
200 Ok(())
201 }
202
203 pub fn path(&self) -> &Path {
205 &self.path
206 }
207
208 pub fn env(&self) -> Option<&E> {
210 self.env.as_ref()
211 }
212
213 pub fn file(&self) -> &File {
215 &self.file
216 }
217
218 pub fn enter(&self) -> Result<()> {
222 setns(self.file.as_raw_fd(), CloneFlags::CLONE_NEWNET).map_err(Error::SetnsError)
223 }
224
225 pub fn get_from_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
227 let ns_path = env.persist_dir().join(ns_name.as_ref());
228 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
229
230 Ok(Self {
231 file,
232 path: ns_path,
233 env: Some(env),
234 })
235 }
236
237 pub fn remove(self) -> Result<()> {
241 nix::unistd::close(self.file.into_raw_fd()).map_err(Error::CloseNsError)?;
243 if let Some(env) = &self.env {
245 if env.contains(&self.path) {
246 Self::umount_ns(&self.path)?;
247 }
248 }
249 Ok(())
250 }
251
252 fn umount_ns<P: AsRef<Path>>(path: P) -> Result<()> {
253 let path = path.as_ref();
254 umount2(path, MntFlags::MNT_DETACH).map_err(|e| Error::UnmountError(path.to_owned(), e))?;
255 std::fs::remove_file(path).ok();
257 Ok(())
258 }
259
260 pub fn run<F, T>(&self, f: F) -> Result<T>
264 where
265 F: FnOnce(&Self) -> T,
266 {
267 let src_ns = get_from_current_thread()?;
269
270 if &src_ns == self {
272 return Ok(f(self));
273 }
274 self.enter()?;
276
277 let result = f(self);
278 src_ns.enter()?;
280
281 Ok(result)
282 }
283}
284
285impl NetNs {
286 pub fn new<S: AsRef<str>>(ns_name: S) -> Result<Self> {
297 Self::new_with_env(ns_name, DefaultEnv)
298 }
299
300 pub fn get<S: AsRef<str>>(ns_name: S) -> Result<Self> {
302 Self::get_from_env(ns_name, DefaultEnv)
303 }
304
305 pub fn run_in<S, F, T>(ns_name: S, f: F) -> Result<T>
309 where
310 S: AsRef<str>,
311 F: FnOnce(&Self) -> T,
312 {
313 let run_ns = Self::get_from_env(ns_name, DefaultEnv)?;
315 run_ns.run(f)
316 }
317}
318
319pub fn get_from_path<P: AsRef<Path>>(ns_path: P) -> Result<NetNs> {
321 let ns_path = ns_path.as_ref().to_path_buf();
322 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
323
324 Ok(NetNs {
325 file,
326 path: ns_path,
327 env: None,
328 })
329}
330
331pub fn get_from_current_thread() -> Result<NetNs> {
333 let ns_path = get_current_thread_netns_path();
334 let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
335
336 Ok(NetNs {
337 file,
338 path: ns_path,
339 env: None,
340 })
341}
342
343#[inline]
344fn get_current_thread_netns_path() -> PathBuf {
345 PathBuf::from(format!("/proc/self/task/{}/ns/net", gettid()))
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::FromRawFd;

    /// Builds a `NetNs` around a descriptor number that is not open, so `metadata()` fails.
    ///
    /// The result is never dropped: dropping it would `close` a descriptor this process
    /// never opened (twice, when two handles share a number). Some std versions abort on
    /// that in debug builds.
    fn bogus_netns(fd: i32, path: &str) -> ManuallyDrop<NetNs<DefaultEnv>> {
        ManuallyDrop::new(NetNs {
            // SAFETY: this deliberately breaks `from_raw_fd`'s "must be open" contract for
            // test purposes. The `File` is only used for `as_raw_fd` and a failing `fstat`,
            // and it is never dropped, so no descriptor is ever closed through it.
            file: unsafe { File::from_raw_fd(fd) },
            path: PathBuf::from(path),
            env: None,
        })
    }

    #[test]
    fn test_netns_display() {
        let ns = get_from_current_thread().unwrap();
        let print = format!("{}", ns);
        assert!(print.contains("dev"));
        assert!(print.contains("ino"));

        let ns = bogus_netns(i32::MAX, "");
        let print = format!("{}", *ns);
        assert!(!print.contains("dev"));
        assert!(!print.contains("ino"));
    }

    #[test]
    fn test_netns_eq() {
        let ns1 = get_from_current_thread().unwrap();
        let ns2 = get_from_path("/proc/self/ns/net").unwrap();
        assert_eq!(ns1, ns2);

        // Same raw descriptor: equal without consulting metadata or paths.
        let ns1 = bogus_netns(i32::MAX, "aaaaaa");
        let ns2 = bogus_netns(i32::MAX, "bbbbbb");
        assert_eq!(ns1, ns2);

        // Different descriptors whose metadata is unavailable: falls back to the path.
        let ns3 = bogus_netns(i32::MAX - 1, "aaaaaa");
        assert_eq!(ns1, ns3);
    }

    #[test]
    fn test_netns_init() {
        let ns = NetNs::new("test_netns_init").unwrap();
        assert!(ns.path().exists());
        ns.remove().unwrap();
        assert!(!Path::new(&DefaultEnv.persist_dir())
            .join("test_netns_init")
            .exists());
    }

    /// Creates a namespace in `DefaultEnv` and removes it (asserting removal) on drop.
    struct TestNetNs {
        netns: Option<NetNs>,
        ns_name: String,
    }

    impl TestNetNs {
        fn new(name: &str) -> Self {
            let netns = NetNs::new(name).unwrap();
            assert!(netns.path().exists());
            Self {
                netns: Some(netns),
                ns_name: String::from(name),
            }
        }

        fn netns(&self) -> &NetNs {
            self.netns.as_ref().unwrap()
        }
    }

    impl Drop for TestNetNs {
        fn drop(&mut self) {
            let ns_name = self.ns_name.clone();
            self.netns.take().unwrap().remove().unwrap();
            assert!(!Path::new(&DefaultEnv.persist_dir()).join(ns_name).exists());
        }
    }

    #[test]
    fn test_netns_enter() {
        let new = TestNetNs::new("test_netns_enter");

        let src = get_from_current_thread().unwrap();
        assert_ne!(&src, new.netns());

        new.netns().enter().unwrap();

        let cur = get_from_current_thread().unwrap();

        assert_eq!(new.netns(), &cur);
        assert_ne!(src, cur);
        assert_ne!(&src, new.netns());
    }

    struct TestEnv;
    impl Env for TestEnv {
        fn persist_dir(&self) -> PathBuf {
            PathBuf::from("/tmp/test_netns")
        }
    }

    #[test]
    fn test_netns_with_env() {
        let ns_res = NetNs::get_from_env("test_netns_run", TestEnv);
        assert!(matches!(ns_res, Err(Error::OpenNsError(_, _))));

        let ns = NetNs::new_with_env("test_netns_run", TestEnv).unwrap();
        assert!(ns.path().exists());

        ns.remove().unwrap();
        // Check the namespace that was actually created (was "test_netns_set", a no-op check).
        assert!(!Path::new(&TestEnv.persist_dir())
            .join("test_netns_run")
            .exists());
    }

    #[test]
    fn test_netns_run() {
        let new = TestNetNs::new("test_netns_run");

        let src_ns = get_from_current_thread().unwrap();

        let ret = new
            .netns()
            .run(|cur_ns| -> Result<()> {
                let cur_thread = get_from_current_thread().unwrap();
                assert_eq!(cur_ns, &cur_thread);
                assert_eq!(cur_ns, new.netns());
                assert_ne!(cur_ns, &src_ns);

                Ok(())
            })
            .unwrap();
        assert!(matches!(ret, Ok(_)));
    }
}