Skip to main content

gritty/
security.rs

1use std::io;
2use std::os::fd::{FromRawFd, OwnedFd, RawFd};
3use std::os::unix::fs::{MetadataExt, PermissionsExt};
4use std::path::Path;
5
/// Largest column or row count that `clamp_winsize` lets through.
const MAX_WINSIZE: u16 = 10_000;
7
8/// Create a directory hierarchy with mode 0700, validating ownership of existing components.
9/// Trusted system roots (`/`, `/tmp`, `/run`, `$XDG_RUNTIME_DIR`) are accepted without
10/// ownership checks. All other existing directories must be owned by the current user
11/// and must not be symlinks.
12///
13/// Uses create-then-validate instead of check-then-create to avoid TOCTOU races.
14pub fn secure_create_dir_all(path: &Path) -> io::Result<()> {
15    if is_trusted_root(path) {
16        return Ok(());
17    }
18
19    if let Some(parent) = path.parent() {
20        secure_create_dir_all(parent)?;
21    }
22
23    // Try to create; if it already exists, validate ownership/type.
24    // The umask (0o077 set by daemon) ensures 0o700 mode on creation.
25    match std::fs::create_dir(path) {
26        Ok(()) => {
27            // Explicitly set permissions in case umask wasn't set (e.g. client-side calls).
28            std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700))?;
29            Ok(())
30        }
31        Err(e) if e.kind() == io::ErrorKind::AlreadyExists => validate_dir(path),
32        Err(e) => Err(e),
33    }
34}
35
36/// Bind a `UnixListener` with TOCTOU-safe stale socket handling and 0600 permissions.
37///
38/// On `AddrInUse`, probes the existing socket: if it responds to a connect, returns
39/// an error (socket is alive). Otherwise, removes the stale socket and retries.
40pub fn bind_unix_listener(path: &Path) -> io::Result<tokio::net::UnixListener> {
41    match tokio::net::UnixListener::bind(path) {
42        Ok(listener) => {
43            set_socket_permissions(path)?;
44            Ok(listener)
45        }
46        Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
47            match std::os::unix::net::UnixStream::connect(path) {
48                Ok(_) => Err(io::Error::new(
49                    io::ErrorKind::AddrInUse,
50                    format!("{} is already in use by a running process", path.display()),
51                )),
52                Err(_) => {
53                    std::fs::remove_file(path)?;
54                    let listener = tokio::net::UnixListener::bind(path)?;
55                    set_socket_permissions(path)?;
56                    Ok(listener)
57                }
58            }
59        }
60        Err(e) => Err(e),
61    }
62}
63
64/// Verify that the peer on a Unix stream has the same UID as the current process.
65pub fn verify_peer_uid(stream: &tokio::net::UnixStream) -> io::Result<()> {
66    let cred = stream.peer_cred()?;
67    let my_uid = unsafe { libc::getuid() };
68    if cred.uid() != my_uid {
69        return Err(io::Error::new(
70            io::ErrorKind::PermissionDenied,
71            format!("rejecting connection from uid {} (expected {my_uid})", cred.uid()),
72        ));
73    }
74    Ok(())
75}
76
77/// `dup(2)` that returns an `OwnedFd` or an error (instead of silently returning -1).
78pub fn checked_dup(fd: RawFd) -> io::Result<OwnedFd> {
79    let new_fd = unsafe { libc::dup(fd) };
80    if new_fd == -1 {
81        return Err(io::Error::last_os_error());
82    }
83    Ok(unsafe { OwnedFd::from_raw_fd(new_fd) })
84}
85
86/// Clamp window-size values to a sane range, preventing zero-sized or absurdly large values.
87pub fn clamp_winsize(cols: u16, rows: u16) -> (u16, u16) {
88    (cols.clamp(1, MAX_WINSIZE), rows.clamp(1, MAX_WINSIZE))
89}
90
/// Restrict the file at `path` to owner read/write (mode 0600).
fn set_socket_permissions(path: &Path) -> io::Result<()> {
    let owner_rw_only = std::fs::Permissions::from_mode(0o600);
    std::fs::set_permissions(path, owner_rw_only)
}
94
/// Report whether `path` is one of the roots that are accepted without an ownership check.
///
/// `/`, `/tmp` and `/run` are matched as exact strings. `$XDG_RUNTIME_DIR`, when set,
/// is matched with `Path` equality.
fn is_trusted_root(path: &Path) -> bool {
    if matches!(path.to_str(), Some("/" | "/tmp" | "/run")) {
        return true;
    }
    // `var_os` rather than `var`: Unix paths need not be UTF-8, and `var` reports a
    // non-UTF-8 value as an error, which would silently drop the configured root.
    std::env::var_os("XDG_RUNTIME_DIR").is_some_and(|xdg| path == Path::new(&xdg))
}
101
102fn validate_dir(path: &Path) -> io::Result<()> {
103    let meta = std::fs::symlink_metadata(path)?;
104
105    // Root-owned entries are system-managed (e.g. /var -> /private/var on macOS).
106    // Just verify the path resolves to a directory.
107    if meta.uid() == 0 {
108        if !path.is_dir() {
109            return Err(io::Error::new(
110                io::ErrorKind::InvalidInput,
111                format!("{} does not resolve to a directory", path.display()),
112            ));
113        }
114        return Ok(());
115    }
116
117    if meta.file_type().is_symlink() {
118        return Err(io::Error::new(
119            io::ErrorKind::InvalidInput,
120            format!("refusing to use symlink at {}", path.display()),
121        ));
122    }
123    if !meta.is_dir() {
124        return Err(io::Error::new(
125            io::ErrorKind::InvalidInput,
126            format!("{} is not a directory", path.display()),
127        ));
128    }
129
130    let uid = unsafe { libc::getuid() };
131    if meta.uid() != uid {
132        return Err(io::Error::new(
133            io::ErrorKind::PermissionDenied,
134            format!(
135                "{} is owned by uid {}, expected uid {uid}; \
136                 set $XDG_RUNTIME_DIR or use --ctl-socket",
137                path.display(),
138                meta.uid()
139            ),
140        ));
141    }
142
143    Ok(())
144}
145
// Unit tests for the helpers in this file. Filesystem tests work inside a fresh
// `tempfile::tempdir()`; socket tests run under the tokio test runtime.
#[cfg(test)]
mod tests {
    use super::*;

    // --- clamp_winsize ---

    // Zero in either dimension is raised to 1.
    #[test]
    fn clamp_winsize_zeros_to_minimum() {
        assert_eq!(clamp_winsize(0, 0), (1, 1));
    }

    // In-range values come back unchanged.
    #[test]
    fn clamp_winsize_normal_passthrough() {
        assert_eq!(clamp_winsize(80, 24), (80, 24));
    }

    // The upper bound itself is allowed.
    #[test]
    fn clamp_winsize_max_boundary() {
        assert_eq!(clamp_winsize(10_000, 10_000), (10_000, 10_000));
    }

    // One past the upper bound is pulled back to it.
    #[test]
    fn clamp_winsize_over_max_clamped() {
        assert_eq!(clamp_winsize(10_001, 10_001), (10_000, 10_000));
    }

    // The largest representable input is still clamped to the upper bound.
    #[test]
    fn clamp_winsize_extreme_values() {
        assert_eq!(clamp_winsize(u16::MAX, u16::MAX), (10_000, 10_000));
    }

    // Columns and rows are clamped independently of each other.
    #[test]
    fn clamp_winsize_asymmetric() {
        assert_eq!(clamp_winsize(0, 80), (1, 80));
        assert_eq!(clamp_winsize(20_000, 5), (10_000, 5));
    }

    // --- secure_create_dir_all ---

    // A multi-level path that does not exist yet is created, and the leaf has mode 0700.
    #[test]
    fn secure_create_dir_all_fresh_hierarchy() {
        let tmp = tempfile::tempdir().unwrap();
        let deep = tmp.path().join("a").join("b").join("c");
        secure_create_dir_all(&deep).unwrap();
        assert!(deep.is_dir());
        let mode = std::fs::metadata(&deep).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o700);
    }

    // Calling it again on a directory it already created is not an error.
    #[test]
    fn secure_create_dir_all_idempotent() {
        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path().join("mydir");
        secure_create_dir_all(&dir).unwrap();
        secure_create_dir_all(&dir).unwrap(); // second call succeeds
    }

    // An existing symlink at the target path is refused with InvalidInput.
    #[test]
    fn secure_create_dir_all_rejects_symlink() {
        let tmp = tempfile::tempdir().unwrap();
        let real = tmp.path().join("real");
        std::fs::create_dir(&real).unwrap();
        let link = tmp.path().join("link");
        std::os::unix::fs::symlink(&real, &link).unwrap();
        let err = secure_create_dir_all(&link).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    // An existing regular file at the target path is refused with InvalidInput.
    #[test]
    fn secure_create_dir_all_rejects_regular_file() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("not_a_dir");
        std::fs::write(&file, b"").unwrap();
        let err = secure_create_dir_all(&file).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn secure_create_dir_all_trusted_roots() {
        // These should succeed without ownership checks
        secure_create_dir_all(Path::new("/")).unwrap();
        secure_create_dir_all(Path::new("/tmp")).unwrap();
    }

    // --- bind_unix_listener ---

    // Binding at an unused path succeeds and leaves the socket file with mode 0600.
    #[tokio::test]
    async fn bind_unix_listener_fresh() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("test.sock");
        let _listener = bind_unix_listener(&sock).unwrap();
        let mode = std::fs::metadata(&sock).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600);
    }

    // A leftover socket file with nothing behind it is replaced rather than reported.
    #[tokio::test]
    async fn bind_unix_listener_stale_socket() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("stale.sock");
        // Create a stale socket file using UnixDatagram (never calls listen()).
        // This avoids a macOS kernel race where connect() briefly succeeds on a
        // just-closed listening socket.
        drop(std::os::unix::net::UnixDatagram::bind(&sock).unwrap());
        assert!(sock.exists());
        // Re-bind should clean up stale socket and succeed
        let _listener = bind_unix_listener(&sock).unwrap();
    }

    // While the first listener is still held, a second bind reports AddrInUse.
    #[tokio::test]
    async fn bind_unix_listener_live_socket_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("live.sock");
        let _listener = bind_unix_listener(&sock).unwrap();
        // Try to bind again while listener is alive
        let err = bind_unix_listener(&sock).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);
    }

    // A missing parent directory surfaces as an error; either NotFound or Other is accepted.
    #[tokio::test]
    async fn bind_unix_listener_nonexistent_dir() {
        let err = bind_unix_listener(Path::new("/no/such/dir/test.sock")).unwrap_err();
        assert!(err.kind() == io::ErrorKind::NotFound || err.kind() == io::ErrorKind::Other);
    }

    // --- verify_peer_uid ---

    // Both ends of a socket pair belong to this process, so the UID check passes.
    #[tokio::test]
    async fn verify_peer_uid_same_process() {
        let (a, _b) = tokio::net::UnixStream::pair().unwrap();
        verify_peer_uid(&a).unwrap();
    }

    // --- checked_dup ---

    // Duplicating stdout yields a descriptor numbered above 2 (assumes fds 0-2 are open).
    #[test]
    fn checked_dup_stdout() {
        use std::os::fd::AsRawFd;
        let fd = checked_dup(1).unwrap();
        assert!(fd.as_raw_fd() > 2);
    }

    // An invalid descriptor produces an error instead of an `OwnedFd`.
    #[test]
    fn checked_dup_invalid_fd() {
        assert!(checked_dup(-1).is_err());
    }
}