Skip to main content

netns_rs/
netns.rs

1// Copyright 2022 Alibaba Cloud. All Rights Reserved.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5
6use std::fs::File;
7use std::os::unix::fs::MetadataExt;
8use std::os::unix::io::AsRawFd;
9use std::path::{Path, PathBuf};
10use std::thread::{self, JoinHandle};
11
12use nix::mount::{mount, umount2, MntFlags, MsFlags};
13use nix::sched::{setns, unshare, CloneFlags};
14use nix::unistd::gettid;
15
16use crate::{Error, Result};
17
18/// Defines a NetNs environment behavior.
19pub trait Env {
20    /// The persist directory of the NetNs environment.
21    fn persist_dir(&self) -> PathBuf;
22
23    /// Returns `true` if the given path is in this Env.
24    fn contains<P: AsRef<Path>>(&self, p: P) -> bool {
25        p.as_ref().starts_with(self.persist_dir())
26    }
27
28    /// Initialize the environment.
29    fn init(&self) -> Result<()> {
30        // Create the directory for mounting network namespaces.
31        // This needs to be a shared mountpoint in case it is mounted in to
32        // other namespaces (containers)
33        let persist_dir = self.persist_dir();
34        std::fs::create_dir_all(&persist_dir).map_err(Error::CreateNsDirError)?;
35
36        // Remount the namespace directory shared. This will fail if it is not
37        // already a mountpoint, so bind-mount it on to itself to "upgrade" it
38        // to a mountpoint.
39        let mut made_netns_persist_dir_mount: bool = false;
40        while let Err(e) = mount(
41            Some(""),
42            &persist_dir,
43            Some("none"),
44            MsFlags::MS_SHARED | MsFlags::MS_REC,
45            Some(""),
46        ) {
47            // Fail unless we need to make the mount point
48            if e != nix::errno::Errno::EINVAL || made_netns_persist_dir_mount {
49                return Err(Error::MountError(
50                    format!("--make-rshared {}", persist_dir.display()),
51                    e,
52                ));
53            }
54            // Recursively remount /var/persist/netns on itself. The recursive flag is
55            // so that any existing netns bindmounts are carried over.
56            mount(
57                Some(&persist_dir),
58                &persist_dir,
59                Some("none"),
60                MsFlags::MS_BIND | MsFlags::MS_REC,
61                Some(""),
62            )
63            .map_err(|e| {
64                Error::MountError(
65                    format!(
66                        "-rbind {} to {}",
67                        persist_dir.display(),
68                        persist_dir.display()
69                    ),
70                    e,
71                )
72            })?;
73            made_netns_persist_dir_mount = true;
74        }
75
76        Ok(())
77    }
78}
79
80/// A default network namespace environment.
81///
82/// Its persistence directory is `/var/run/netns`, which is for consistency with the `ip-netns` tool.
83/// See [ip-netns](https://man7.org/linux/man-pages/man8/ip-netns.8.html) for details.
84#[derive(Copy, Clone, Default, Debug)]
85pub struct DefaultEnv;
86
87impl Env for DefaultEnv {
88    fn persist_dir(&self) -> PathBuf {
89        PathBuf::from("/var/run/netns")
90    }
91}
92
93/// A network namespace type.
94///
95/// It could be used to enter network namespace.
96#[derive(Debug)]
97pub struct NetNs<E: Env = DefaultEnv> {
98    file: File,
99    path: PathBuf,
100    env: Option<E>,
101}
102
103impl<E: Env> std::fmt::Display for NetNs<E> {
104    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
105        if let Ok(meta) = self.file.metadata() {
106            write!(
107                f,
108                "NetNS {{ fd: {}, dev: {}, ino: {}, path: {} }}",
109                self.file.as_raw_fd(),
110                meta.dev(),
111                meta.ino(),
112                self.path.display()
113            )
114        } else {
115            write!(
116                f,
117                "NetNS {{ fd: {}, path: {} }}",
118                self.file.as_raw_fd(),
119                self.path.display()
120            )
121        }
122    }
123}
124
125impl<E1: Env, E2: Env> PartialEq<NetNs<E1>> for NetNs<E2> {
126    fn eq(&self, other: &NetNs<E1>) -> bool {
127        if self.file.as_raw_fd() == other.file.as_raw_fd() {
128            return true;
129        }
130        let cmp_meta = |f1: &File, f2: &File| -> Option<bool> {
131            let m1 = match f1.metadata() {
132                Ok(m) => m,
133                Err(_) => return None,
134            };
135            let m2 = match f2.metadata() {
136                Ok(m) => m,
137                Err(_) => return None,
138            };
139            Some(m1.dev() == m2.dev() && m1.ino() == m2.ino())
140        };
141        cmp_meta(&self.file, &other.file).unwrap_or_else(|| self.path == other.path)
142    }
143}
144
145impl<E: Env> NetNs<E> {
146    /// Creates a new `NetNs` with the specified name and Env.
147    ///
148    /// The persist dir of network namespace will be created if it doesn't already exist.
149    pub fn new_with_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
150        env.init()?;
151
152        // create an empty file at the mount point
153        let ns_path = env.persist_dir().join(ns_name.as_ref());
154        let _ = File::create(&ns_path).map_err(Error::CreateNsError)?;
155        Self::persistent(&ns_path, true).inspect_err(|_e| {
156            // Ensure the mount point is cleaned up on errors; if the namespace was successfully
157            // mounted this will have no effect because the file is in-use
158            std::fs::remove_file(&ns_path).ok();
159        })?;
160        Self::get_from_env(ns_name, env)
161    }
162
163    fn persistent<P: AsRef<Path>>(ns_path: &P, new_thread: bool) -> Result<()> {
164        if new_thread {
165            let ns_path_clone = ns_path.as_ref().to_path_buf();
166            let new_thread: JoinHandle<Result<()>> =
167                thread::spawn(move || Self::persistent(&ns_path_clone, false));
168            match new_thread.join() {
169                Ok(t) => t?,
170                Err(e) => {
171                    return Err(Error::JoinThreadError(format!("{:?}", e)));
172                }
173            };
174        } else {
175            // Create a new netns for the current thread.
176            unshare(CloneFlags::CLONE_NEWNET).map_err(Error::UnshareError)?;
177            // bind mount the netns from the current thread (from /proc) onto the mount point.
178            // This persists the namespace, even when there are no threads in the ns.
179            let src = get_current_thread_netns_path();
180            mount(
181                Some(src.as_path()),
182                ns_path.as_ref(),
183                Some("none"),
184                MsFlags::MS_BIND,
185                Some(""),
186            )
187            .map_err(|e| {
188                Error::MountError(
189                    format!("rbind {} to {}", src.display(), ns_path.as_ref().display()),
190                    e,
191                )
192            })?;
193        }
194
195        Ok(())
196    }
197
198    /// Gets the path of this network namespace.
199    pub fn path(&self) -> &Path {
200        &self.path
201    }
202
203    /// Gets the Env of this network namespace.
204    pub fn env(&self) -> Option<&E> {
205        self.env.as_ref()
206    }
207
208    /// Gets the Env of this network namespace.
209    pub fn file(&self) -> &File {
210        &self.file
211    }
212
213    /// Makes the current thread enter this network namespace.
214    ///
215    /// Requires elevated privileges.
216    pub fn enter(&self) -> Result<()> {
217        setns(&self.file, CloneFlags::CLONE_NEWNET).map_err(Error::SetnsError)
218    }
219
220    /// Returns the NetNs with the specified name and Env.
221    pub fn get_from_env<S: AsRef<str>>(ns_name: S, env: E) -> Result<Self> {
222        let ns_path = env.persist_dir().join(ns_name.as_ref());
223        let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
224
225        Ok(Self {
226            file,
227            path: ns_path,
228            env: Some(env),
229        })
230    }
231
232    /// Removes this network namespace manually.
233    ///
234    /// Once called, this instance will not be available.
235    pub fn remove(self) -> Result<()> {
236        // Close the file descriptor by dropping it.
237        drop(self.file);
238        // Only unmount if it's been bind-mounted (don't touch namespaces in /proc...)
239        if let Some(env) = &self.env {
240            if env.contains(&self.path) {
241                Self::umount_ns(&self.path)?;
242            }
243        }
244        Ok(())
245    }
246
247    fn umount_ns<P: AsRef<Path>>(path: P) -> Result<()> {
248        let path = path.as_ref();
249        umount2(path, MntFlags::MNT_DETACH).map_err(|e| Error::UnmountError(path.to_owned(), e))?;
250        // Donot return error.
251        std::fs::remove_file(path).ok();
252        Ok(())
253    }
254
255    /// Run a closure in NetNs, which is specified by name and Env.
256    ///
257    /// Requires elevated privileges.
258    pub fn run<F, T>(&self, f: F) -> Result<T>
259    where
260        F: FnOnce(&Self) -> T,
261    {
262        // get current network namespace
263        let src_ns = get_from_current_thread()?;
264
265        // do nothing if ns_path is same as current_ns
266        if &src_ns == self {
267            return Ok(f(self));
268        }
269        // enter new namespace
270        self.enter()?;
271
272        let result = f(self);
273        // back to old namespace
274        src_ns.enter()?;
275
276        Ok(result)
277    }
278}
279
280impl NetNs {
281    /// Creates a new persistent (bind-mounted) network namespace and returns an object representing
282    /// that namespace, without switching to it.
283    ///
284    /// The persist directory of network namespace will be created if it doesn't already exist.
285    /// This function will use [`DefaultEnv`] to create persist directory.
286    ///
287    /// Requires elevated privileges.
288    ///
289    /// [`DefaultEnv`]: DefaultEnv
290    ///
291    pub fn new<S: AsRef<str>>(ns_name: S) -> Result<Self> {
292        Self::new_with_env(ns_name, DefaultEnv)
293    }
294
295    /// Returns the NetNs with the specified name and `DefaultEnv`.
296    pub fn get<S: AsRef<str>>(ns_name: S) -> Result<Self> {
297        Self::get_from_env(ns_name, DefaultEnv)
298    }
299
300    /// Run a closure in NetNs, which is specified by name and `DefaultEnv`.
301    ///
302    /// Requires elevated privileges.
303    pub fn run_in<S, F, T>(ns_name: S, f: F) -> Result<T>
304    where
305        S: AsRef<str>,
306        F: FnOnce(&Self) -> T,
307    {
308        // get network namespace
309        let run_ns = Self::get_from_env(ns_name, DefaultEnv)?;
310        run_ns.run(f)
311    }
312}
313
314/// Returns the NetNs with the spectified path.
315pub fn get_from_path<P: AsRef<Path>>(ns_path: P) -> Result<NetNs> {
316    let ns_path = ns_path.as_ref().to_path_buf();
317    let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
318
319    Ok(NetNs {
320        file,
321        path: ns_path,
322        env: None,
323    })
324}
325
326/// Returns the NetNs of current thread.
327pub fn get_from_current_thread() -> Result<NetNs> {
328    let ns_path = get_current_thread_netns_path();
329    let file = File::open(&ns_path).map_err(|e| Error::OpenNsError(ns_path.clone(), e))?;
330
331    Ok(NetNs {
332        file,
333        path: ns_path,
334        env: None,
335    })
336}
337
338#[inline]
339fn get_current_thread_netns_path() -> PathBuf {
340    PathBuf::from(format!("/proc/self/task/{}/ns/net", gettid()))
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::FromRawFd;

    /// Builds a `NetNs` around an arbitrary raw fd without ever closing it.
    fn make_dummy_netns(fd: i32, path: &str) -> ManuallyDrop<NetNs<DefaultEnv>> {
        ManuallyDrop::new(NetNs {
            // SAFETY: the `File` does not own `fd`; it is wrapped in
            // `ManuallyDrop` so it is never dropped and the descriptor is
            // never closed through it.
            file: unsafe { File::from_raw_fd(fd) },
            path: PathBuf::from(path),
            env: None,
        })
    }

    #[test]
    fn test_netns_display() {
        let ns = get_from_current_thread().unwrap();
        let print = format!("{}", ns);
        assert!(print.contains("dev"));
        assert!(print.contains("ino"));

        // The dummy fd is presumably not open, so metadata lookup fails and
        // the short form is printed.
        let ns = make_dummy_netns(i32::MAX, "");
        let print = format!("{}", *ns);
        assert!(!print.contains("dev"));
        assert!(!print.contains("ino"));
    }

    #[test]
    fn test_netns_eq() {
        let ns1 = get_from_current_thread().unwrap();
        let ns2 = get_from_path("/proc/self/ns/net").unwrap();
        assert_eq!(ns1, ns2);

        // Same raw fd: equal regardless of path.
        let ns1 = make_dummy_netns(i32::MAX, "aaaaaa");
        let ns2 = make_dummy_netns(i32::MAX, "bbbbbb");
        assert_eq!(*ns1, *ns2);

        // Different fds: equality falls back to the path comparison here.
        let ns2 = make_dummy_netns(i32::MAX - 1, "aaaaaa");
        assert_eq!(*ns1, *ns2);
    }

    #[test]
    fn test_netns_init() {
        let ns = NetNs::new("test_netns_init").unwrap();
        assert!(ns.path().exists());
        ns.remove().unwrap();
        assert!(!Path::new(&DefaultEnv.persist_dir())
            .join("test_netns_init")
            .exists());
    }

    /// Test fixture that creates a namespace under `DefaultEnv` and removes
    /// it again when dropped.
    struct TestNetNs {
        netns: Option<NetNs>,
        ns_name: String,
    }

    impl TestNetNs {
        fn new(name: &str) -> Self {
            let netns = NetNs::new(name).unwrap();
            assert!(netns.path().exists());
            Self {
                netns: Some(netns),
                ns_name: String::from(name),
            }
        }

        fn netns(&self) -> &NetNs {
            self.netns.as_ref().unwrap()
        }
    }

    impl Drop for TestNetNs {
        fn drop(&mut self) {
            let ns_name = self.ns_name.clone();
            self.netns.take().unwrap().remove().unwrap();
            assert!(!Path::new(&DefaultEnv.persist_dir()).join(ns_name).exists());
        }
    }

    #[test]
    fn test_netns_enter() {
        let new = TestNetNs::new("test_netns_enter");

        let src = get_from_current_thread().unwrap();
        assert_ne!(&src, new.netns());

        new.netns().enter().unwrap();

        let cur = get_from_current_thread().unwrap();

        assert_eq!(new.netns(), &cur);
        assert_ne!(src, cur);
        assert_ne!(&src, new.netns());
    }

    struct TestEnv;
    impl Env for TestEnv {
        fn persist_dir(&self) -> PathBuf {
            PathBuf::from("/tmp/test_netns")
        }
    }

    #[test]
    fn test_netns_with_env() {
        let ns_res = NetNs::get_from_env("test_netns_run", TestEnv);
        assert!(matches!(ns_res, Err(Error::OpenNsError(_, _))));

        let ns = NetNs::new_with_env("test_netns_run", TestEnv).unwrap();
        assert!(ns.path().exists());

        ns.remove().unwrap();
        // Check the namespace that was actually created and removed above.
        assert!(!Path::new(&TestEnv.persist_dir())
            .join("test_netns_run")
            .exists());
    }

    #[test]
    fn test_netns_run() {
        let new = TestNetNs::new("test_netns_run");

        let src_ns = get_from_current_thread().unwrap();

        let ret = new
            .netns()
            .run(|cur_ns| -> Result<()> {
                let cur_thread = get_from_current_thread().unwrap();
                assert_eq!(cur_ns, &cur_thread);
                // captured variables
                assert_eq!(cur_ns, new.netns());
                assert_ne!(cur_ns, &src_ns);

                Ok(())
            })
            .unwrap();
        assert!(ret.is_ok());
    }
}