1use std::io;
2use std::os::fd::{FromRawFd, OwnedFd, RawFd};
3use std::os::unix::fs::{MetadataExt, PermissionsExt};
4use std::path::Path;
5
/// Upper bound applied to each terminal dimension (columns and rows) by [`clamp_winsize`].
const MAX_WINSIZE: u16 = 10_000;
7
8pub fn secure_create_dir_all(path: &Path) -> io::Result<()> {
13 if path.exists() {
14 if is_trusted_root(path) {
15 return Ok(());
16 }
17 return validate_dir(path);
18 }
19
20 if let Some(parent) = path.parent() {
21 secure_create_dir_all(parent)?;
22 }
23
24 std::fs::create_dir(path)?;
25 std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700))?;
26 Ok(())
27}
28
29pub fn bind_unix_listener(path: &Path) -> io::Result<tokio::net::UnixListener> {
34 match tokio::net::UnixListener::bind(path) {
35 Ok(listener) => {
36 set_socket_permissions(path)?;
37 Ok(listener)
38 }
39 Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
40 match std::os::unix::net::UnixStream::connect(path) {
41 Ok(_) => Err(io::Error::new(
42 io::ErrorKind::AddrInUse,
43 format!("{} is already in use by a running process", path.display()),
44 )),
45 Err(_) => {
46 std::fs::remove_file(path)?;
47 let listener = tokio::net::UnixListener::bind(path)?;
48 set_socket_permissions(path)?;
49 Ok(listener)
50 }
51 }
52 }
53 Err(e) => Err(e),
54 }
55}
56
57pub fn verify_peer_uid(stream: &tokio::net::UnixStream) -> io::Result<()> {
59 let cred = stream.peer_cred()?;
60 let my_uid = unsafe { libc::getuid() };
61 if cred.uid() != my_uid {
62 return Err(io::Error::new(
63 io::ErrorKind::PermissionDenied,
64 format!("rejecting connection from uid {} (expected {my_uid})", cred.uid()),
65 ));
66 }
67 Ok(())
68}
69
70pub fn checked_dup(fd: RawFd) -> io::Result<OwnedFd> {
72 let new_fd = unsafe { libc::dup(fd) };
73 if new_fd == -1 {
74 return Err(io::Error::last_os_error());
75 }
76 Ok(unsafe { OwnedFd::from_raw_fd(new_fd) })
77}
78
79pub fn clamp_winsize(cols: u16, rows: u16) -> (u16, u16) {
81 (cols.clamp(1, MAX_WINSIZE), rows.clamp(1, MAX_WINSIZE))
82}
83
/// Restricts the file at `path` to read/write for its owner only (mode `0o600`).
fn set_socket_permissions(path: &Path) -> io::Result<()> {
    let owner_rw_only = std::fs::Permissions::from_mode(0o600);
    std::fs::set_permissions(path, owner_rw_only)
}
87
/// Reports whether `path` is a system-managed directory that is trusted without an ownership
/// check: `/`, `/tmp`, `/run`, or the directory named by `$XDG_RUNTIME_DIR`.
///
/// Comparison is by path components, so spellings such as `/tmp/` match `/tmp`. Per the XDG
/// Base Directory specification, an empty or relative `$XDG_RUNTIME_DIR` is invalid and is
/// ignored.
fn is_trusted_root(path: &Path) -> bool {
    const FIXED_ROOTS: [&str; 3] = ["/", "/tmp", "/run"];

    if FIXED_ROOTS.iter().any(|root| path == Path::new(root)) {
        return true;
    }
    // `var_os` rather than `var` so a non-UTF-8 runtime dir is still honoured.
    std::env::var_os("XDG_RUNTIME_DIR")
        .map(std::path::PathBuf::from)
        .is_some_and(|xdg| xdg.is_absolute() && path == xdg.as_path())
}
94
95fn validate_dir(path: &Path) -> io::Result<()> {
96 let meta = std::fs::symlink_metadata(path)?;
97
98 if meta.file_type().is_symlink() {
99 return Err(io::Error::new(
100 io::ErrorKind::InvalidInput,
101 format!("refusing to use symlink at {}", path.display()),
102 ));
103 }
104 if !meta.is_dir() {
105 return Err(io::Error::new(
106 io::ErrorKind::InvalidInput,
107 format!("{} is not a directory", path.display()),
108 ));
109 }
110
111 let uid = unsafe { libc::getuid() };
112 if meta.uid() != uid {
113 return Err(io::Error::new(
114 io::ErrorKind::PermissionDenied,
115 format!(
116 "{} is owned by uid {}, expected uid {uid}; \
117 set $XDG_RUNTIME_DIR or use --ctl-socket",
118 path.display(),
119 meta.uid()
120 ),
121 ));
122 }
123
124 Ok(())
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clamp_winsize_zeros_to_minimum() {
        assert_eq!(clamp_winsize(0, 0), (1, 1));
    }

    #[test]
    fn clamp_winsize_normal_passthrough() {
        assert_eq!(clamp_winsize(80, 24), (80, 24));
    }

    #[test]
    fn clamp_winsize_max_boundary() {
        assert_eq!(clamp_winsize(10_000, 10_000), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_over_max_clamped() {
        assert_eq!(clamp_winsize(10_001, 10_001), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_extreme_values() {
        assert_eq!(clamp_winsize(u16::MAX, u16::MAX), (10_000, 10_000));
    }

    #[test]
    fn clamp_winsize_asymmetric() {
        assert_eq!(clamp_winsize(0, 80), (1, 80));
        assert_eq!(clamp_winsize(20_000, 5), (10_000, 5));
    }

    #[test]
    fn secure_create_dir_all_fresh_hierarchy() {
        let tmp = tempfile::tempdir().unwrap();
        let deep = tmp.path().join("a").join("b").join("c");
        secure_create_dir_all(&deep).unwrap();
        assert!(deep.is_dir());
        // Every directory the call created, not only the leaf, must be owner-only.
        for created in [tmp.path().join("a"), tmp.path().join("a").join("b"), deep.clone()] {
            let mode = std::fs::metadata(&created).unwrap().permissions().mode() & 0o777;
            assert_eq!(mode, 0o700, "unexpected mode on {}", created.display());
        }
    }

    #[test]
    fn secure_create_dir_all_idempotent() {
        let tmp = tempfile::tempdir().unwrap();
        let dir = tmp.path().join("mydir");
        secure_create_dir_all(&dir).unwrap();
        secure_create_dir_all(&dir).unwrap();
        assert!(dir.is_dir());
    }

    #[test]
    fn secure_create_dir_all_rejects_symlink() {
        let tmp = tempfile::tempdir().unwrap();
        let real = tmp.path().join("real");
        std::fs::create_dir(&real).unwrap();
        let link = tmp.path().join("link");
        std::os::unix::fs::symlink(&real, &link).unwrap();
        let err = secure_create_dir_all(&link).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn secure_create_dir_all_rejects_regular_file() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("not_a_dir");
        std::fs::write(&file, b"").unwrap();
        let err = secure_create_dir_all(&file).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn secure_create_dir_all_trusted_roots() {
        // These are owned by root, so they only pass via the trusted-root shortcut.
        secure_create_dir_all(Path::new("/")).unwrap();
        secure_create_dir_all(Path::new("/tmp")).unwrap();
    }

    #[tokio::test]
    async fn bind_unix_listener_fresh() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("test.sock");
        let _listener = bind_unix_listener(&sock).unwrap();
        let mode = std::fs::metadata(&sock).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600);
    }

    #[tokio::test]
    async fn bind_unix_listener_stale_socket() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("stale.sock");
        // Bind then immediately close a socket so its file is left behind with no listener.
        drop(std::os::unix::net::UnixDatagram::bind(&sock).unwrap());
        assert!(sock.exists());
        let _listener = bind_unix_listener(&sock).unwrap();
        // The replacement socket gets the same owner-only mode as a fresh bind ...
        let mode = std::fs::metadata(&sock).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600);
        // ... and is actually listening.
        std::os::unix::net::UnixStream::connect(&sock).unwrap();
    }

    #[tokio::test]
    async fn bind_unix_listener_live_socket_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let sock = tmp.path().join("live.sock");
        let _listener = bind_unix_listener(&sock).unwrap();
        let err = bind_unix_listener(&sock).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);
        // The live socket must not have been removed by the failed second bind.
        assert!(sock.exists());
    }

    #[tokio::test]
    async fn bind_unix_listener_nonexistent_dir() {
        let err = bind_unix_listener(Path::new("/no/such/dir/test.sock")).unwrap_err();
        assert!(err.kind() == io::ErrorKind::NotFound || err.kind() == io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn verify_peer_uid_same_process() {
        // Both ends of a socketpair belong to this process, so the uids must match.
        let (a, _b) = tokio::net::UnixStream::pair().unwrap();
        verify_peer_uid(&a).unwrap();
    }

    #[test]
    fn checked_dup_stdout() {
        use std::os::fd::AsRawFd;
        let fd = checked_dup(1).unwrap();
        // 0-2 are taken by stdio, so a fresh duplicate must land above them.
        assert!(fd.as_raw_fd() > 2);
    }

    #[test]
    fn checked_dup_invalid_fd() {
        assert!(checked_dup(-1).is_err());
    }
}