Skip to main content

gritty/
security.rs

1use std::io;
2use std::os::fd::{FromRawFd, OwnedFd, RawFd};
3use std::os::unix::fs::{MetadataExt, PermissionsExt};
4use std::path::Path;
5
/// Largest value `clamp_winsize` lets through for either terminal dimension
/// (columns or rows).
const MAX_WINSIZE: u16 = 10_000;
7
8/// Create a directory hierarchy with mode 0700, validating ownership of existing components.
9/// Trusted system roots (`/`, `/tmp`, `/run`, `$XDG_RUNTIME_DIR`) are accepted without
10/// ownership checks. All other existing directories must be owned by the current user
11/// and must not be symlinks.
12pub fn secure_create_dir_all(path: &Path) -> io::Result<()> {
13    if path.exists() {
14        if is_trusted_root(path) {
15            return Ok(());
16        }
17        return validate_dir(path);
18    }
19
20    if let Some(parent) = path.parent() {
21        secure_create_dir_all(parent)?;
22    }
23
24    std::fs::create_dir(path)?;
25    std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700))?;
26    Ok(())
27}
28
29/// Bind a `UnixListener` with TOCTOU-safe stale socket handling and 0600 permissions.
30///
31/// On `AddrInUse`, probes the existing socket: if it responds to a connect, returns
32/// an error (socket is alive). Otherwise, removes the stale socket and retries.
33pub fn bind_unix_listener(path: &Path) -> io::Result<tokio::net::UnixListener> {
34    match tokio::net::UnixListener::bind(path) {
35        Ok(listener) => {
36            set_socket_permissions(path)?;
37            Ok(listener)
38        }
39        Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
40            match std::os::unix::net::UnixStream::connect(path) {
41                Ok(_) => Err(io::Error::new(
42                    io::ErrorKind::AddrInUse,
43                    format!("{} is already in use by a running process", path.display()),
44                )),
45                Err(_) => {
46                    std::fs::remove_file(path)?;
47                    let listener = tokio::net::UnixListener::bind(path)?;
48                    set_socket_permissions(path)?;
49                    Ok(listener)
50                }
51            }
52        }
53        Err(e) => Err(e),
54    }
55}
56
57/// Verify that the peer on a Unix stream has the same UID as the current process.
58pub fn verify_peer_uid(stream: &tokio::net::UnixStream) -> io::Result<()> {
59    let cred = stream.peer_cred()?;
60    let my_uid = unsafe { libc::getuid() };
61    if cred.uid() != my_uid {
62        return Err(io::Error::new(
63            io::ErrorKind::PermissionDenied,
64            format!("rejecting connection from uid {} (expected {my_uid})", cred.uid()),
65        ));
66    }
67    Ok(())
68}
69
70/// `dup(2)` that returns an `OwnedFd` or an error (instead of silently returning -1).
71pub fn checked_dup(fd: RawFd) -> io::Result<OwnedFd> {
72    let new_fd = unsafe { libc::dup(fd) };
73    if new_fd == -1 {
74        return Err(io::Error::last_os_error());
75    }
76    Ok(unsafe { OwnedFd::from_raw_fd(new_fd) })
77}
78
79/// Clamp window-size values to a sane range, preventing zero-sized or absurdly large values.
80pub fn clamp_winsize(cols: u16, rows: u16) -> (u16, u16) {
81    (cols.clamp(1, MAX_WINSIZE), rows.clamp(1, MAX_WINSIZE))
82}
83
/// Restrict the file at `path` to owner read/write only (mode 0600).
fn set_socket_permissions(path: &Path) -> io::Result<()> {
    let owner_rw_only = std::fs::Permissions::from_mode(0o600);
    std::fs::set_permissions(path, owner_rw_only)
}
87
/// Whether `path` is a system directory that may be used without an ownership check:
/// `/`, `/tmp`, `/run`, or the directory named by `$XDG_RUNTIME_DIR`.
///
/// Comparison is done on path components, so `/tmp/` and `/tmp` are treated alike.
/// A relative `$XDG_RUNTIME_DIR` is ignored, as the XDG Base Directory specification
/// requires. Trusting it would skip ownership checks for a directory that depends on
/// the current working directory.
fn is_trusted_root(path: &Path) -> bool {
    const SYSTEM_ROOTS: [&str; 3] = ["/", "/tmp", "/run"];

    if SYSTEM_ROOTS.iter().any(|root| path == Path::new(root)) {
        return true;
    }

    // `var_os` also handles non-UTF-8 values, which `var` would silently drop.
    std::env::var_os("XDG_RUNTIME_DIR").is_some_and(|xdg| {
        let xdg = Path::new(&xdg);
        xdg.is_absolute() && path == xdg
    })
}
94
95fn validate_dir(path: &Path) -> io::Result<()> {
96    let meta = std::fs::symlink_metadata(path)?;
97
98    if meta.file_type().is_symlink() {
99        return Err(io::Error::new(
100            io::ErrorKind::InvalidInput,
101            format!("refusing to use symlink at {}", path.display()),
102        ));
103    }
104    if !meta.is_dir() {
105        return Err(io::Error::new(
106            io::ErrorKind::InvalidInput,
107            format!("{} is not a directory", path.display()),
108        ));
109    }
110
111    let uid = unsafe { libc::getuid() };
112    if meta.uid() != uid {
113        return Err(io::Error::new(
114            io::ErrorKind::PermissionDenied,
115            format!(
116                "{} is owned by uid {}, expected uid {uid}; \
117                 set $XDG_RUNTIME_DIR or use --ctl-socket",
118                path.display(),
119                meta.uid()
120            ),
121        ));
122    }
123
124    Ok(())
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::os::fd::AsRawFd;

    /// Lower nine permission bits of the file or directory at `path` (symlinks followed).
    fn perm_bits(path: &Path) -> u32 {
        std::fs::metadata(path).unwrap().permissions().mode() & 0o777
    }

    // --- clamp_winsize ---

    #[test]
    fn clamp_winsize_zeros_to_minimum() {
        let clamped = clamp_winsize(0, 0);
        assert_eq!(clamped, (1, 1));
    }

    #[test]
    fn clamp_winsize_normal_passthrough() {
        assert_eq!(clamp_winsize(80, 24), (80, 24));
    }

    #[test]
    fn clamp_winsize_max_boundary() {
        assert_eq!(clamp_winsize(MAX_WINSIZE, MAX_WINSIZE), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_over_max_clamped() {
        assert_eq!(clamp_winsize(10_001, 10_001), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_extreme_values() {
        assert_eq!(clamp_winsize(u16::MAX, u16::MAX), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_asymmetric() {
        let cases = [((0, 80), (1, 80)), ((20_000, 5), (10_000, 5))];
        for ((cols, rows), expected) in cases {
            assert_eq!(clamp_winsize(cols, rows), expected);
        }
    }

    // --- secure_create_dir_all ---

    #[test]
    fn secure_create_dir_all_fresh_hierarchy() {
        let root = tempfile::tempdir().unwrap();
        let leaf = root.path().join("a/b/c");
        secure_create_dir_all(&leaf).unwrap();
        assert!(leaf.is_dir());
        assert_eq!(perm_bits(&leaf), 0o700);
    }

    #[test]
    fn secure_create_dir_all_idempotent() {
        let root = tempfile::tempdir().unwrap();
        let target = root.path().join("mydir");
        // Both the creating call and the repeat call must succeed.
        for _ in 0..2 {
            secure_create_dir_all(&target).unwrap();
        }
    }

    #[test]
    fn secure_create_dir_all_rejects_symlink() {
        let root = tempfile::tempdir().unwrap();
        let (real, link) = (root.path().join("real"), root.path().join("link"));
        std::fs::create_dir(&real).unwrap();
        std::os::unix::fs::symlink(&real, &link).unwrap();
        let err = secure_create_dir_all(&link).expect_err("symlinked directory must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn secure_create_dir_all_rejects_regular_file() {
        let root = tempfile::tempdir().unwrap();
        let file = root.path().join("not_a_dir");
        std::fs::write(&file, b"").unwrap();
        let err = secure_create_dir_all(&file).expect_err("regular file must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn secure_create_dir_all_trusted_roots() {
        // Trusted roots are accepted without any ownership check.
        for root in ["/", "/tmp"] {
            secure_create_dir_all(Path::new(root)).unwrap();
        }
    }

    // --- bind_unix_listener ---

    #[tokio::test]
    async fn bind_unix_listener_fresh() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("test.sock");
        let _listener = bind_unix_listener(&sock).unwrap();
        assert_eq!(perm_bits(&sock), 0o600);
    }

    #[tokio::test]
    async fn bind_unix_listener_stale_socket() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("stale.sock");
        // Leave a stale socket file behind via a datagram socket, which never calls
        // listen(). This sidesteps a macOS kernel race where connect() can briefly
        // succeed against a listening socket that was just closed.
        let datagram = std::os::unix::net::UnixDatagram::bind(&sock).unwrap();
        drop(datagram);
        assert!(sock.exists());
        // Binding again must clean up the stale socket file and succeed.
        let _listener = bind_unix_listener(&sock).unwrap();
    }

    #[tokio::test]
    async fn bind_unix_listener_live_socket_errors() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("live.sock");
        let _first = bind_unix_listener(&sock).unwrap();
        // A second bind while the first listener is still alive must fail.
        let err = bind_unix_listener(&sock).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);
    }

    #[tokio::test]
    async fn bind_unix_listener_nonexistent_dir() {
        let err = bind_unix_listener(Path::new("/no/such/dir/test.sock")).unwrap_err();
        assert!(matches!(err.kind(), io::ErrorKind::NotFound | io::ErrorKind::Other));
    }

    // --- verify_peer_uid ---

    #[tokio::test]
    async fn verify_peer_uid_same_process() {
        let (near, _far) = tokio::net::UnixStream::pair().unwrap();
        verify_peer_uid(&near).unwrap();
    }

    // --- checked_dup ---

    #[test]
    fn checked_dup_stdout() {
        let dup = checked_dup(1).unwrap();
        assert!(dup.as_raw_fd() > 2);
    }

    #[test]
    fn checked_dup_invalid_fd() {
        assert!(checked_dup(-1).is_err());
    }
}