Skip to main content

libfuse_fs/util/
bind_mount.rs

1// Copyright (C) 2024 rk8s authors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3//! Bind mount utilities for container volume management
4
5use std::io::{Error, Result};
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8use tokio::sync::Mutex;
9use tracing::{debug, error, info};
10
/// Represents a single bind mount: a host path and the place it should appear.
#[derive(Debug, Clone)]
pub struct BindMount {
    /// Source path on host. [`BindMount::parse`] always produces an absolute path here.
    pub source: PathBuf,
    /// Target path relative to mount point
    pub target: PathBuf,
}

impl BindMount {
    /// Parse a bind mount specification like "proc:/proc" or "/host/path:/container/path".
    ///
    /// The spec must be exactly two non-empty components separated by a single
    /// `:`. A relative source is anchored at `/` (so `"proc"` becomes `/proc`);
    /// the target is stored exactly as written.
    ///
    /// # Errors
    ///
    /// Returns an error when the spec does not contain exactly one `:`, or when
    /// either the source or the target component is empty.
    pub fn parse(spec: &str) -> Result<Self> {
        // Exactly one ':' is allowed: split at the first one and make sure the
        // remainder does not contain another.
        let (source, target) = match spec.split_once(':') {
            Some((source, target)) if !target.contains(':') => (source, target),
            _ => {
                return Err(Error::other(format!(
                    "Invalid bind mount spec: '{}'. Expected format: 'source:target'",
                    spec
                )));
            }
        };

        // An empty source would silently resolve to "/" (binding the whole host
        // root) and an empty target to the mount point itself, so refuse both.
        if source.is_empty() || target.is_empty() {
            return Err(Error::other(format!(
                "Invalid bind mount spec: '{}'. Source and target must be non-empty",
                spec
            )));
        }

        // Convert relative source paths to absolute from root. `join` replaces
        // the base when its argument is already absolute, so absolute sources
        // pass through unchanged.
        let source = Path::new("/").join(source);
        let target = PathBuf::from(target);

        Ok(BindMount { source, target })
    }
}
44
45/// Manages multiple bind mounts with automatic cleanup
46pub struct BindMountManager {
47    mounts: Arc<Mutex<Vec<MountPoint>>>,
48    mountpoint: PathBuf,
49}
50
51#[derive(Debug)]
52struct MountPoint {
53    target: PathBuf,
54    mounted: bool,
55}
56
57impl BindMountManager {
58    /// Create a new bind mount manager
59    pub fn new<P: AsRef<Path>>(mountpoint: P) -> Self {
60        Self {
61            mounts: Arc::new(Mutex::new(Vec::new())),
62            mountpoint: mountpoint.as_ref().to_path_buf(),
63        }
64    }
65
66    /// Mount all bind mounts
67    pub async fn mount_all(&self, bind_specs: &[BindMount]) -> Result<()> {
68        let mut mounts = self.mounts.lock().await;
69
70        for bind in bind_specs {
71            let target_path = self
72                .mountpoint
73                .join(bind.target.strip_prefix("/").unwrap_or(&bind.target));
74
75            // Check if source is a file or directory
76            let source_metadata = std::fs::metadata(&bind.source)?;
77
78            if !target_path.exists() {
79                if source_metadata.is_file() {
80                    // For file bind mounts, create parent directory and an empty file
81                    if let Some(parent) = target_path.parent() {
82                        std::fs::create_dir_all(parent)?;
83                        debug!("Created parent directory: {:?}", parent);
84                    }
85                    std::fs::File::create(&target_path)?;
86                    debug!("Created target file: {:?}", target_path);
87                } else {
88                    // For directory bind mounts, create the directory
89                    std::fs::create_dir_all(&target_path)?;
90                    debug!("Created target directory: {:?}", target_path);
91                }
92            }
93
94            // Perform the bind mount
95            self.do_mount(&bind.source, &target_path)?;
96
97            mounts.push(MountPoint {
98                target: target_path.clone(),
99                mounted: true,
100            });
101
102            info!("Bind mounted {:?} -> {:?}", bind.source, target_path);
103        }
104
105        Ok(())
106    }
107
108    /// Perform the actual bind mount using mount(2) syscall
109    fn do_mount(&self, source: &Path, target: &Path) -> Result<()> {
110        use std::ffi::CString;
111
112        let source_cstr = CString::new(
113            source
114                .to_str()
115                .ok_or_else(|| Error::other(format!("Invalid source path: {:?}", source)))?,
116        )
117        .map_err(|e| Error::other(format!("CString error: {}", e)))?;
118
119        let target_cstr = CString::new(
120            target
121                .to_str()
122                .ok_or_else(|| Error::other(format!("Invalid target path: {:?}", target)))?,
123        )
124        .map_err(|e| Error::other(format!("CString error: {}", e)))?;
125
126        let fstype = CString::new("none").unwrap();
127
128        let ret = unsafe {
129            libc::mount(
130                source_cstr.as_ptr(),
131                target_cstr.as_ptr(),
132                fstype.as_ptr(),
133                libc::MS_BIND | libc::MS_REC,
134                std::ptr::null(),
135            )
136        };
137
138        if ret != 0 {
139            let err = Error::last_os_error();
140            error!("Failed to bind mount {:?} to {:?}: {}", source, target, err);
141            return Err(err);
142        }
143
144        // Prevent mount propagation issues by making the mount point a slave.
145        // This ensures that unmounting the target doesn't propagate back to the host/source
146        // if they are part of a shared subtree (which is common on modern Linux).
147        let ret = unsafe {
148            libc::mount(
149                std::ptr::null(),
150                target_cstr.as_ptr(),
151                std::ptr::null(),
152                libc::MS_SLAVE | libc::MS_REC,
153                std::ptr::null(),
154            )
155        };
156
157        if ret != 0 {
158            let err = Error::last_os_error();
159            error!("Failed to set mount propagation for {:?}: {}", target, err);
160            // Attempt cleanup
161            unsafe { libc::umount2(target_cstr.as_ptr(), libc::MNT_DETACH) };
162            return Err(err);
163        }
164
165        Ok(())
166    }
167
168    /// Unmount all bind mounts
169    pub async fn unmount_all(&self) -> Result<()> {
170        let mut mounts = self.mounts.lock().await;
171        let mut errors = Vec::new();
172
173        // Unmount in reverse order
174        while let Some(mut mount) = mounts.pop() {
175            if mount.mounted {
176                if let Err(e) = self.do_unmount(&mount.target) {
177                    error!("Failed to unmount {:?}: {}", mount.target, e);
178                    errors.push(e);
179                } else {
180                    mount.mounted = false;
181                    info!("Unmounted {:?}", mount.target);
182                }
183            }
184        }
185
186        if !errors.is_empty() {
187            return Err(Error::other(format!(
188                "Failed to unmount {} bind mounts",
189                errors.len()
190            )));
191        }
192
193        Ok(())
194    }
195
196    /// Perform the actual unmount using umount(2) syscall
197    fn do_unmount(&self, target: &Path) -> Result<()> {
198        use std::ffi::CString;
199
200        let target_cstr = CString::new(
201            target
202                .to_str()
203                .ok_or_else(|| Error::other(format!("Invalid target path: {:?}", target)))?,
204        )
205        .map_err(|e| Error::other(format!("CString error: {}", e)))?;
206
207        let ret = unsafe { libc::umount2(target_cstr.as_ptr(), libc::MNT_DETACH) };
208
209        if ret != 0 {
210            let err = Error::last_os_error();
211            // EINVAL or ENOENT might mean it's already unmounted
212            if err.raw_os_error() != Some(libc::EINVAL) && err.raw_os_error() != Some(libc::ENOENT)
213            {
214                return Err(err);
215            }
216        }
217
218        Ok(())
219    }
220}
221
222impl Drop for BindMountManager {
223    fn drop(&mut self) {
224        // Attempt to clean up on drop (synchronously)
225        let mounts = self.mounts.try_lock();
226        if let Ok(mut mounts) = mounts {
227            while let Some(mount) = mounts.pop() {
228                if mount.mounted {
229                    let _ = self.do_unmount(&mount.target);
230                }
231            }
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed specs yield an absolute source and the target as written.
    #[test]
    fn test_parse_bind_mount() {
        // (spec, expected source, expected target)
        let cases = [
            ("proc:/proc", "/proc", "/proc"),
            ("/host/path:/container/path", "/host/path", "/container/path"),
            ("sys:/sys", "/sys", "/sys"),
        ];

        for (spec, expected_source, expected_target) in cases {
            let bind = BindMount::parse(spec).unwrap();
            assert_eq!(bind.source, PathBuf::from(expected_source));
            assert_eq!(bind.target, PathBuf::from(expected_target));
        }
    }

    /// Specs without exactly one `:` separator are rejected.
    #[test]
    fn test_invalid_bind_mount() {
        for spec in ["invalid", "too:many:colons"] {
            assert!(BindMount::parse(spec).is_err());
        }
    }
}