netns_rs/
netns.rs

1// Copyright 2022 Alibaba Cloud. All Rights Reserved.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5
6use std::fs::File;
7use std::os::unix::fs::MetadataExt;
8use std::os::unix::io::{AsRawFd, IntoRawFd};
9use std::path::{Path, PathBuf};
10use std::thread::{self, JoinHandle};
11
12use nix::mount::{mount, umount2, MntFlags, MsFlags};
13use nix::sched::{setns, unshare, CloneFlags};
14use nix::unistd::gettid;
15
16use crate::{Error, Result};
17
18/// Defines a NetNs environment behavior.
19pub trait Env {
20    /// The persist directory of the NetNs environment.
21    fn persist_dir(&self) -> PathBuf;
22
23    /// Returns `true` if the given path is in this Env.
24    fn contains<P: AsRef<Path>>(&self, p: P) -> bool {
25        p.as_ref().starts_with(self.persist_dir())
26    }
27
28    /// Initialize the environment.
29    fn init(&self) -> Result<()> {
30        // Create the directory for mounting network namespaces.
31        // This needs to be a shared mountpoint in case it is mounted in to
32        // other namespaces (containers)
33        let persist_dir = self.persist_dir();
34        std::fs::create_dir_all(&persist_dir).map_err(Error::CreateNsDirError)?;
35
36        // Remount the namespace directory shared. This will fail if it is not
37        // already a mountpoint, so bind-mount it on to itself to "upgrade" it
38        // to a mountpoint.
39        let mut made_netns_persist_dir_mount: bool = false;
40        while let Err(e) = mount(
41            Some(""),
42            &persist_dir,
43            Some("none"),
44            MsFlags::MS_SHARED | MsFlags::MS_REC,
45            Some(""),
46        ) {
47            // Fail unless we need to make the mount point
48            if e != nix::errno::Errno::EINVAL || made_netns_persist_dir_mount {
49                return Err(Error::MountError(
50                    format!("--make-rshared {}", persist_dir.display()),
51                    e,
52                ));
53            }
54            // Recursively remount /var/persist/netns on itself. The recursive flag is
55            // so that any existing netns bindmounts are carried over.
56            mount(
57                Some(&persist_dir),
58                &persist_dir,
59                Some("none"),
60                MsFlags::MS_BIND | MsFlags::MS_REC,
61                Some(""),
62            )
63            .map_err(|e| {
64                Error::MountError(
65                    format!(
66                        "-rbind {} to {}",
67                        persist_dir.display(),
68                        persist_dir.display()
69                    ),
70                    e,
71                )
72            })?;
73            made_netns_persist_dir_mount = true;
74        }
75
76        Ok(())
77    }
78}
79
80/// A default network namespace environment.
81///
82/// Its persistence directory is `/var/run/netns`, which is for consistency with the `ip-netns` tool.
83/// See [ip-netns](https://man7.org/linux/man-pages/man8/ip-netns.8.html) for details.
84#[derive(Copy, Clone, Default, Debug)]
85pub struct DefaultEnv;
86
87impl Env for DefaultEnv {
88    fn persist_dir(&self) -> PathBuf {
89        PathBuf::from("/var/run/netns")
90    }
91}
92
93/// A network namespace type.
94///
95/// It could be used to enter network namespace.
96#[derive(Debug)]
97pub struct NetNs<E: Env = DefaultEnv> {
98    file: File,
99    path: PathBuf,
100    env: Option<E>,
101}
102
103impl<E: Env> std::fmt::Display for NetNs<E> {
104    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
105        if let Ok(meta) = self.file.metadata() {
106            write!(
107                f,
108                "NetNS {{ fd: {}, dev: {}, ino: {}, path: {} }}",
109                self.file.as_raw_fd(),
110                meta.dev(),
111                meta.ino(),
112                self.path.display()
113            )
114        } else {
115            write!(
116                f,
117                "NetNS {{ fd: {}, path: {} }}",
118                self.file.as_raw_fd(),
119                self.path.display()
120            )
121        }
122    }
123}
124
125impl<E1: Env, E2: Env> PartialEq<NetNs<E1>> for NetNs<E2> {
126    fn eq(&self, other: &NetNs<E1>) -> bool {
127        if self.file.as_raw_fd() == other.file.as_raw_fd() {
128            return true;
129        }
130        let cmp_meta = |f1: &File, f2: &File| -> Option<bool> {
131            let m1 = match f1.metadata() {
132                Ok(m) => m,
133                Err(_) => return None,
134            };
135            let m2 = match f2.metadata() {
136                Ok(m) => m,
137                Err(_) => return None,
138            };
139            Some(m1.dev() == m2.dev() && m1.ino() == m2.ino())
140        };
141        cmp_meta(&self.file, &other.file).unwrap_or_else(|| self.path == other.path)
142    }
143}
144
145impl<E: Env> NetNs<E> {
146    /// Creates a new `NetNs` with the specified name and Env.
147    ///
148    /// The persist dir of network namespace will be created if it doesn't already exist.
149    pub fn new_with_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
150        env.init()?;
151
152        // create an empty file at the mount point
153        let ns_path = env.persist_dir().join(ns_name.as_ref());
154        let _ = File::create(&ns_path).map_err(Error::CreateNsError)?;
155        Self::persistent(&ns_path, true).map_err(|e| {
156            // Ensure the mount point is cleaned up on errors; if the namespace was successfully
157            // mounted this will have no effect because the file is in-use
158            std::fs::remove_file(&ns_path).ok();
159            e
160        })?;
161        Self::get_from_env(ns_name, env)
162    }
163
164    fn persistent<P: AsRef<Path>>(ns_path: &P, new_thread: bool) -> Result<()> {
165        if new_thread {
166            let ns_path_clone = ns_path.as_ref().to_path_buf();
167            let new_thread: JoinHandle<Result<()>> =
168                thread::spawn(move || Self::persistent(&ns_path_clone, false));
169            match new_thread.join() {
170                Ok(t) => {
171                    if let Err(e) = t {
172                        return Err(e);
173                    }
174                }
175                Err(e) => {
176                    return Err(Error::JoinThreadError(format!("{:?}", e)));
177                }
178            };
179        } else {
180            // Create a new netns for the current thread.
181            unshare(CloneFlags::CLONE_NEWNET).map_err(Error::UnshareError)?;
182            // bind mount the netns from the current thread (from /proc) onto the mount point.
183            // This persists the namespace, even when there are no threads in the ns.
184            let src = get_current_thread_netns_path();
185            mount(
186                Some(src.as_path()),
187                ns_path.as_ref(),
188                Some("none"),
189                MsFlags::MS_BIND,
190                Some(""),
191            )
192            .map_err(|e| {
193                Error::MountError(
194                    format!("rbind {} to {}", src.display(), ns_path.as_ref().display()),
195                    e,
196                )
197            })?;
198        }
199
200        Ok(())
201    }
202
203    /// Gets the path of this network namespace.
204    pub fn path(&self) -> &Path {
205        &self.path
206    }
207
208    /// Gets the Env of this network namespace.
209    pub fn env(&self) -> Option<&E> {
210        self.env.as_ref()
211    }
212
213    /// Gets the Env of this network namespace.
214    pub fn file(&self) -> &File {
215        &self.file
216    }
217
218    /// Makes the current thread enter this network namespace.
219    ///
220    /// Requires elevated privileges.
221    pub fn enter(&self) -> Result<()> {
222        setns(self.file.as_raw_fd(), CloneFlags::CLONE_NEWNET).map_err(Error::SetnsError)
223    }
224
225    /// Returns the NetNs with the specified name and Env.
226    pub fn get_from_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
227        let ns_path = env.persist_dir().join(ns_name.as_ref());
228        let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
229
230        Ok(Self {
231            file,
232            path: ns_path,
233            env: Some(env),
234        })
235    }
236
237    /// Removes this network namespace manually.
238    ///
239    /// Once called, this instance will not be available.
240    pub fn remove(self) -> Result<()> {
241        // need close first
242        nix::unistd::close(self.file.into_raw_fd()).map_err(Error::CloseNsError)?;
243        // Only unmount if it's been bind-mounted (don't touch namespaces in /proc...)
244        if let Some(env) = &self.env {
245            if env.contains(&self.path) {
246                Self::umount_ns(&self.path)?;
247            }
248        }
249        Ok(())
250    }
251
252    fn umount_ns<P: AsRef<Path>>(path: P) -> Result<()> {
253        let path = path.as_ref();
254        umount2(path, MntFlags::MNT_DETACH).map_err(|e| Error::UnmountError(path.to_owned(), e))?;
255        // Donot return error.
256        std::fs::remove_file(path).ok();
257        Ok(())
258    }
259
260    /// Run a closure in NetNs, which is specified by name and Env.
261    ///
262    /// Requires elevated privileges.
263    pub fn run<F, T>(&self, f: F) -> Result<T>
264    where
265        F: FnOnce(&Self) -> T,
266    {
267        // get current network namespace
268        let src_ns = get_from_current_thread()?;
269
270        // do nothing if ns_path is same as current_ns
271        if &src_ns == self {
272            return Ok(f(self));
273        }
274        // enter new namespace
275        self.enter()?;
276
277        let result = f(self);
278        // back to old namespace
279        src_ns.enter()?;
280
281        Ok(result)
282    }
283}
284
285impl NetNs {
286    /// Creates a new persistent (bind-mounted) network namespace and returns an object representing
287    /// that namespace, without switching to it.
288    ///
289    /// The persist directory of network namespace will be created if it doesn't already exist.
290    /// This function will use [`DefaultEnv`] to create persist directory.
291    ///
292    /// Requires elevated privileges.
293    ///
294    /// [`DefaultEnv`]: DefaultEnv
295    ///
296    pub fn new<S: AsRef<str>>(ns_name: S) -> Result<Self> {
297        Self::new_with_env(ns_name, DefaultEnv)
298    }
299
300    /// Returns the NetNs with the specified name and `DefaultEnv`.
301    pub fn get<S: AsRef<str>>(ns_name: S) -> Result<Self> {
302        Self::get_from_env(ns_name, DefaultEnv)
303    }
304
305    /// Run a closure in NetNs, which is specified by name and `DefaultEnv`.
306    ///
307    /// Requires elevated privileges.
308    pub fn run_in<S, F, T>(ns_name: S, f: F) -> Result<T>
309    where
310        S: AsRef<str>,
311        F: FnOnce(&Self) -> T,
312    {
313        // get network namespace
314        let run_ns = Self::get_from_env(ns_name, DefaultEnv)?;
315        run_ns.run(f)
316    }
317}
318
319/// Returns the NetNs with the spectified path.
320pub fn get_from_path<P: AsRef<Path>>(ns_path: P) -> Result<NetNs> {
321    let ns_path = ns_path.as_ref().to_path_buf();
322    let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
323
324    Ok(NetNs {
325        file,
326        path: ns_path,
327        env: None,
328    })
329}
330
331/// Returns the NetNs of current thread.
332pub fn get_from_current_thread() -> Result<NetNs> {
333    let ns_path = get_current_thread_netns_path();
334    let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
335
336    Ok(NetNs {
337        file,
338        path: ns_path,
339        env: None,
340    })
341}
342
343#[inline]
344fn get_current_thread_netns_path() -> PathBuf {
345    PathBuf::from(format!("/proc/self/task/{}/ns/net", gettid()))
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::{FromRawFd, RawFd};

    /// Builds a `NetNs` around a descriptor that is not open in this process.
    ///
    /// The value is wrapped in `ManuallyDrop` so its `File` is never dropped. Closing a
    /// descriptor the `File` does not really own violates I/O safety (recent std versions
    /// abort on this in debug builds), and several tests reuse the same descriptor number.
    fn netns_with_bogus_fd(fd: RawFd, path: &str) -> ManuallyDrop<NetNs<DefaultEnv>> {
        // SAFETY (test only): `fd` is deliberately not an open descriptor. It is only used
        // for its number and passed to `fstat` (via `File::metadata`), which fails with
        // EBADF. It is never closed because the wrapper is never dropped.
        let file = unsafe { File::from_raw_fd(fd) };
        ManuallyDrop::new(NetNs {
            file,
            path: PathBuf::from(path),
            env: None,
        })
    }

    #[test]
    fn test_netns_display() {
        let ns = get_from_current_thread().unwrap();
        let print = format!("{}", ns);
        assert!(print.contains("dev"));
        assert!(print.contains("ino"));

        let ns = netns_with_bogus_fd(i32::MAX, "");
        let print = format!("{}", *ns);
        assert!(!print.contains("dev"));
        assert!(!print.contains("ino"));
    }

    #[test]
    fn test_netns_eq() {
        let ns1 = get_from_current_thread().unwrap();
        let ns2 = get_from_path("/proc/self/ns/net").unwrap();
        assert_eq!(ns1, ns2);

        // Same descriptor number: equal regardless of path.
        let ns1 = netns_with_bogus_fd(i32::MAX, "aaaaaa");
        let ns2 = netns_with_bogus_fd(i32::MAX, "bbbbbb");
        assert_eq!(*ns1, *ns2);

        // Metadata unavailable: falls back to comparing paths.
        let ns2 = netns_with_bogus_fd(i32::MAX - 1, "aaaaaa");
        assert_eq!(*ns1, *ns2);
    }

    #[test]
    fn test_netns_init() {
        let ns = NetNs::new("test_netns_init").unwrap();
        assert!(ns.path().exists());
        ns.remove().unwrap();
        assert!(!Path::new(&DefaultEnv.persist_dir())
            .join("test_netns_init")
            .exists());
    }

    struct TestNetNs {
        netns: Option<NetNs>,
        ns_name: String,
    }

    impl TestNetNs {
        fn new(name: &str) -> Self {
            let netns = NetNs::new(name).unwrap();
            assert!(netns.path().exists());
            Self {
                netns: Some(netns),
                ns_name: String::from(name),
            }
        }

        fn netns(&self) -> &NetNs {
            self.netns.as_ref().unwrap()
        }
    }

    impl Drop for TestNetNs {
        fn drop(&mut self) {
            let ns_name = self.ns_name.clone();
            self.netns.take().unwrap().remove().unwrap();
            assert!(!Path::new(&DefaultEnv.persist_dir()).join(ns_name).exists());
        }
    }

    #[test]
    fn test_netns_enter() {
        let new = TestNetNs::new("test_netns_enter");

        let src = get_from_current_thread().unwrap();
        assert_ne!(&src, new.netns());

        new.netns().enter().unwrap();

        let cur = get_from_current_thread().unwrap();

        assert_eq!(new.netns(), &cur);
        assert_ne!(src, cur);
        assert_ne!(&src, new.netns());
    }

    struct TestEnv;
    impl Env for TestEnv {
        fn persist_dir(&self) -> PathBuf {
            PathBuf::from("/tmp/test_netns")
        }
    }

    #[test]
    fn test_netns_with_env() {
        let ns_res = NetNs::get_from_env("test_netns_run", TestEnv);
        assert!(matches!(ns_res, Err(Error::OpenNsError(_, _))));

        let ns = NetNs::new_with_env("test_netns_run", TestEnv).unwrap();
        assert!(ns.path().exists());

        ns.remove().unwrap();
        // Check the namespace that was actually created above.
        assert!(!Path::new(&TestEnv.persist_dir())
            .join("test_netns_run")
            .exists());
    }

    #[test]
    fn test_netns_run() {
        let new = TestNetNs::new("test_netns_run");

        let src_ns = get_from_current_thread().unwrap();

        let ret = new
            .netns()
            .run(|cur_ns| -> Result<()> {
                let cur_thread = get_from_current_thread().unwrap();
                assert_eq!(cur_ns, &cur_thread);
                // captured variables
                assert_eq!(cur_ns, new.netns());
                assert_ne!(cur_ns, &src_ns);

                Ok(())
            })
            .unwrap();
        assert!(matches!(ret, Ok(_)));
    }
}