Skip to main content

claude_agent/security/path/
resolver.rs

1//! TOCTOU-safe path resolution using openat() with O_NOFOLLOW.
2
3use std::ffi::{CString, OsStr, OsString};
4use std::os::unix::ffi::OsStrExt;
5use std::os::unix::io::{AsFd, BorrowedFd, OwnedFd};
6use std::path::{Component, Path, PathBuf};
7use std::sync::Arc;
8
9use rustix::fs::{Mode, OFlags, openat};
10use rustix::io::Errno;
11
12use crate::security::SecurityError;
13
/// A path below a sandbox root, tracked both textually and as a component list.
///
/// Alongside the textual path, the value keeps a shared descriptor for the
/// sandbox root so operations can re-walk `components` with `openat` instead
/// of trusting the textual path.
#[derive(Debug)]
pub struct SafePath {
    /// Shared descriptor of the sandbox root directory.
    root_fd: Arc<OwnedFd>,
    /// Textual path of the sandbox root.
    root_path: PathBuf,
    /// Path components below the root, outermost first.
    components: Vec<OsString>,
    /// Full textual path described by `root_path` and `components`.
    resolved_path: PathBuf,
    /// Whether this value was built without component-by-component validation.
    permissive: bool,
}
22
23impl SafePath {
24    pub fn resolve(
25        root_fd: Arc<OwnedFd>,
26        root_path: PathBuf,
27        relative_path: &Path,
28        max_symlink_depth: u8,
29    ) -> Result<Self, SecurityError> {
30        let mut components = Vec::new();
31        let mut symlink_depth = 0u8;
32
33        for component in relative_path.components() {
34            match component {
35                Component::ParentDir => {
36                    if components.is_empty() {
37                        return Err(SecurityError::PathEscape(relative_path.to_path_buf()));
38                    }
39                    components.pop();
40                }
41                Component::CurDir | Component::RootDir => {}
42                Component::Normal(name) => {
43                    components.push(name.to_os_string());
44                }
45                Component::Prefix(_) => {}
46            }
47        }
48
49        let mut validated_components = Vec::new();
50        let mut current_fd: BorrowedFd<'_> = root_fd.as_fd();
51        let mut owned_fds: Vec<OwnedFd> = Vec::new();
52
53        for (i, component) in components.iter().enumerate() {
54            let is_last = i == components.len() - 1;
55
56            let c_name = CString::new(component.as_bytes())
57                .map_err(|_| SecurityError::InvalidPath("null byte in path".into()))?;
58
59            let flags = if is_last {
60                OFlags::RDONLY | OFlags::NOFOLLOW | OFlags::CLOEXEC
61            } else {
62                OFlags::RDONLY | OFlags::DIRECTORY | OFlags::NOFOLLOW | OFlags::CLOEXEC
63            };
64
65            match openat(current_fd, &c_name, flags, Mode::empty()) {
66                Ok(fd) => {
67                    validated_components.push(component.clone());
68                    if !is_last {
69                        owned_fds.push(fd);
70                        // SAFETY: We just pushed to owned_fds, so last() is guaranteed to be Some
71                        current_fd = owned_fds
72                            .last()
73                            .expect("owned_fds is non-empty after push")
74                            .as_fd();
75                    } else {
76                        drop(fd);
77                    }
78                }
79                Err(Errno::LOOP) | Err(Errno::MLINK) => {
80                    symlink_depth += 1;
81                    if symlink_depth > max_symlink_depth {
82                        return Err(SecurityError::SymlinkDepthExceeded {
83                            path: relative_path.to_path_buf(),
84                            max: max_symlink_depth,
85                        });
86                    }
87
88                    let target = rustix::fs::readlinkat(current_fd, &c_name, vec![0u8; 4096])
89                        .map_err(|e| {
90                            SecurityError::Io(std::io::Error::from_raw_os_error(e.raw_os_error()))
91                        })?;
92
93                    let target_path = PathBuf::from(OsStr::from_bytes(target.to_bytes()));
94                    if target_path.is_absolute() {
95                        if !target_path.starts_with(&root_path) {
96                            return Err(SecurityError::AbsoluteSymlink(target_path));
97                        }
98                        let relative = target_path
99                            .strip_prefix(&root_path)
100                            .expect("path verified with starts_with");
101                        return Self::resolve(
102                            Arc::clone(&root_fd),
103                            root_path,
104                            relative,
105                            max_symlink_depth - symlink_depth,
106                        );
107                    }
108
109                    let mut remaining: Vec<OsString> = target_path
110                        .components()
111                        .filter_map(|c| match c {
112                            Component::Normal(s) => Some(s.to_os_string()),
113                            _ => None,
114                        })
115                        .collect();
116
117                    remaining.extend(components.iter().skip(i + 1).cloned());
118
119                    let new_path: PathBuf = remaining.iter().collect();
120                    let current_path: PathBuf = validated_components.iter().collect();
121                    let full_path = current_path.join(&new_path);
122
123                    return Self::resolve(
124                        Arc::clone(&root_fd),
125                        root_path,
126                        &full_path,
127                        max_symlink_depth - symlink_depth,
128                    );
129                }
130                Err(Errno::NOENT) => {
131                    validated_components.push(component.clone());
132                    validated_components.extend(components.iter().skip(i + 1).cloned());
133                    break;
134                }
135                Err(e) => {
136                    return Err(SecurityError::Io(std::io::Error::from_raw_os_error(
137                        e.raw_os_error(),
138                    )));
139                }
140            }
141        }
142
143        let resolved_path = root_path.join(validated_components.iter().collect::<PathBuf>());
144
145        Ok(Self {
146            root_fd,
147            root_path,
148            components: validated_components,
149            resolved_path,
150            permissive: false,
151        })
152    }
153
154    /// Create a SafePath without validation (for permissive mode).
155    /// This bypasses TOCTOU protection but allows symlinks.
156    pub fn unchecked(root_fd: Arc<OwnedFd>, resolved_path: PathBuf) -> Self {
157        let root_path = PathBuf::from("/");
158        let components = resolved_path
159            .strip_prefix("/")
160            .unwrap_or(&resolved_path)
161            .components()
162            .filter_map(|c| match c {
163                Component::Normal(s) => Some(s.to_os_string()),
164                _ => None,
165            })
166            .collect();
167
168        Self {
169            root_fd,
170            root_path,
171            components,
172            resolved_path,
173            permissive: true,
174        }
175    }
176
177    pub fn is_permissive(&self) -> bool {
178        self.permissive
179    }
180
181    pub fn root_fd(&self) -> BorrowedFd<'_> {
182        self.root_fd.as_fd()
183    }
184
185    pub fn root_path(&self) -> &Path {
186        &self.root_path
187    }
188
189    pub fn components(&self) -> &[OsString] {
190        &self.components
191    }
192
193    pub fn as_path(&self) -> &Path {
194        &self.resolved_path
195    }
196
197    pub fn filename(&self) -> Option<&OsStr> {
198        self.components.last().map(|s| s.as_os_str())
199    }
200
201    pub fn parent_components(&self) -> &[OsString] {
202        if self.components.is_empty() {
203            &[]
204        } else {
205            &self.components[..self.components.len() - 1]
206        }
207    }
208
209    pub fn open(&self, flags: OFlags) -> Result<OwnedFd, SecurityError> {
210        // In permissive mode, use standard library to handle symlinks
211        if self.permissive {
212            use std::fs::OpenOptions;
213            use std::os::unix::fs::OpenOptionsExt;
214
215            let mut opts = OpenOptions::new();
216
217            if flags.contains(OFlags::RDONLY) && !flags.contains(OFlags::WRONLY) {
218                opts.read(true);
219            }
220            if flags.contains(OFlags::WRONLY) || flags.contains(OFlags::RDWR) {
221                opts.write(true);
222            }
223            if flags.contains(OFlags::RDWR) {
224                opts.read(true);
225            }
226            if flags.contains(OFlags::CREATE) {
227                opts.create(true);
228            }
229            if flags.contains(OFlags::TRUNC) {
230                opts.truncate(true);
231            }
232            if flags.contains(OFlags::APPEND) {
233                opts.append(true);
234            }
235
236            opts.mode(0o644);
237
238            let file = opts.open(&self.resolved_path).map_err(SecurityError::Io)?;
239            return Ok(file.into());
240        }
241
242        if self.components.is_empty() {
243            let fd = rustix::fs::openat(
244                self.root_fd.as_fd(),
245                c".",
246                flags | OFlags::CLOEXEC,
247                Mode::empty(),
248            )
249            .map_err(|e| SecurityError::Io(std::io::Error::from_raw_os_error(e.raw_os_error())))?;
250            return Ok(fd);
251        }
252
253        let mut current_fd: BorrowedFd<'_> = self.root_fd.as_fd();
254        let mut owned_fds: Vec<OwnedFd> = Vec::new();
255
256        for (i, component) in self.components.iter().enumerate() {
257            let is_last = i == self.components.len() - 1;
258            let c_name = CString::new(component.as_bytes())
259                .map_err(|_| SecurityError::InvalidPath("null byte".into()))?;
260
261            let open_flags = if is_last {
262                flags | OFlags::NOFOLLOW | OFlags::CLOEXEC
263            } else {
264                OFlags::RDONLY | OFlags::DIRECTORY | OFlags::NOFOLLOW | OFlags::CLOEXEC
265            };
266
267            let fd = openat(current_fd, &c_name, open_flags, Mode::from_raw_mode(0o644)).map_err(
268                |e| SecurityError::Io(std::io::Error::from_raw_os_error(e.raw_os_error())),
269            )?;
270
271            if is_last {
272                return Ok(fd);
273            }
274
275            owned_fds.push(fd);
276            // SAFETY: We just pushed to owned_fds, so last() is guaranteed to be Some
277            current_fd = owned_fds
278                .last()
279                .expect("owned_fds is non-empty after push")
280                .as_fd();
281        }
282
283        unreachable!("loop always returns on is_last")
284    }
285
286    pub fn create_parent_dirs(&self) -> Result<(), SecurityError> {
287        if self.components.len() <= 1 {
288            return Ok(());
289        }
290
291        // In permissive mode, use standard library to handle symlinks
292        if self.permissive {
293            if let Some(parent) = self.resolved_path.parent() {
294                std::fs::create_dir_all(parent)?;
295            }
296            return Ok(());
297        }
298
299        let mut current_fd: BorrowedFd<'_> = self.root_fd.as_fd();
300        let mut owned_fds: Vec<OwnedFd> = Vec::new();
301
302        for component in self.parent_components() {
303            let c_name = CString::new(component.as_bytes())
304                .map_err(|_| SecurityError::InvalidPath("null byte".into()))?;
305
306            match openat(
307                current_fd,
308                &c_name,
309                OFlags::RDONLY | OFlags::DIRECTORY | OFlags::NOFOLLOW | OFlags::CLOEXEC,
310                Mode::empty(),
311            ) {
312                Ok(fd) => {
313                    owned_fds.push(fd);
314                    // SAFETY: We just pushed to owned_fds, so last() is guaranteed to be Some
315                    current_fd = owned_fds
316                        .last()
317                        .expect("owned_fds is non-empty after push")
318                        .as_fd();
319                }
320                Err(Errno::NOENT) => {
321                    rustix::fs::mkdirat(current_fd, &c_name, Mode::from_raw_mode(0o755)).map_err(
322                        |e| SecurityError::Io(std::io::Error::from_raw_os_error(e.raw_os_error())),
323                    )?;
324
325                    let fd = openat(
326                        current_fd,
327                        &c_name,
328                        OFlags::RDONLY | OFlags::DIRECTORY | OFlags::CLOEXEC,
329                        Mode::empty(),
330                    )
331                    .map_err(|e| {
332                        SecurityError::Io(std::io::Error::from_raw_os_error(e.raw_os_error()))
333                    })?;
334
335                    owned_fds.push(fd);
336                    // SAFETY: We just pushed to owned_fds, so last() is guaranteed to be Some
337                    current_fd = owned_fds
338                        .last()
339                        .expect("owned_fds is non-empty after push")
340                        .as_fd();
341                }
342                Err(e) => {
343                    return Err(SecurityError::Io(std::io::Error::from_raw_os_error(
344                        e.raw_os_error(),
345                    )));
346                }
347            }
348        }
349
350        Ok(())
351    }
352}
353
354impl Clone for SafePath {
355    fn clone(&self) -> Self {
356        Self {
357            root_fd: Arc::clone(&self.root_fd),
358            root_path: self.root_path.clone(),
359            components: self.components.clone(),
360            resolved_path: self.resolved_path.clone(),
361            permissive: self.permissive,
362        }
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::os::unix::fs::symlink;
    use tempfile::{TempDir, tempdir};

    /// Creates a temporary sandbox and returns its guard plus canonical path.
    ///
    /// Keep the guard bound for the whole test; dropping it deletes the directory.
    fn sandbox() -> (TempDir, PathBuf) {
        let guard = tempdir().unwrap();
        let root = fs::canonicalize(guard.path()).unwrap();
        (guard, root)
    }

    /// Opens the directory at `path` and shares its descriptor.
    fn open_dir(path: &Path) -> Arc<OwnedFd> {
        let handle = fs::File::open(path).unwrap();
        Arc::new(OwnedFd::from(handle))
    }

    #[test]
    fn test_resolve_simple() {
        let (_guard, root) = sandbox();
        fs::write(root.join("test.txt"), "content").unwrap();

        let resolved =
            SafePath::resolve(open_dir(&root), root.clone(), Path::new("test.txt"), 10).unwrap();

        assert_eq!(resolved.as_path(), root.join("test.txt"));
    }

    #[test]
    fn test_resolve_nonexistent() {
        let (_guard, root) = sandbox();

        let resolved =
            SafePath::resolve(open_dir(&root), root.clone(), Path::new("newfile.txt"), 10)
                .unwrap();

        assert_eq!(resolved.as_path(), root.join("newfile.txt"));
    }

    #[test]
    fn test_path_traversal_blocked() {
        let (_guard, root) = sandbox();

        let outcome =
            SafePath::resolve(open_dir(&root), root, Path::new("../../../etc/passwd"), 10);

        assert!(matches!(outcome, Err(SecurityError::PathEscape(_))));
    }

    #[test]
    fn test_symlink_within_sandbox() {
        let (_guard, root) = sandbox();
        fs::write(root.join("target.txt"), "content").unwrap();
        symlink("target.txt", root.join("link.txt")).unwrap();

        let resolved =
            SafePath::resolve(open_dir(&root), root.clone(), Path::new("link.txt"), 10).unwrap();

        assert_eq!(resolved.as_path(), root.join("target.txt"));
    }

    #[test]
    fn test_symlink_depth_limit() {
        let (_guard, root) = sandbox();

        // link0 -> link1 -> ... -> link14 -> final.txt: 15 hops, above the limit of 10.
        for hop in 0..14 {
            symlink(format!("link{}.txt", hop + 1), root.join(format!("link{hop}.txt"))).unwrap();
        }
        symlink("final.txt", root.join("link14.txt")).unwrap();
        fs::write(root.join("final.txt"), "content").unwrap();

        let outcome = SafePath::resolve(open_dir(&root), root, Path::new("link0.txt"), 10);

        assert!(matches!(
            outcome,
            Err(SecurityError::SymlinkDepthExceeded { .. })
        ));
    }
}