1use nix::{
18 fcntl::{F_SETFD, FdFlag, fcntl},
19 libc,
20};
21use std::{
22 collections::HashMap,
23 fs::{canonicalize, read_dir},
24 os::fd::{FromRawFd, OwnedFd, RawFd},
25 sync::{Mutex, OnceLock},
26};
27use thiserror::Error;
28
/// Maps each inherited raw fd number to its owning handle.
///
/// A slot holds `Some(fd)` while ownership is still available and `None` once
/// `take_fd_ownership` has handed the descriptor out.
type InheritedFdTable = HashMap<RawFd, Option<OwnedFd>>;

/// Process-wide table of inherited file descriptors.
///
/// It is set at most once, by `init_inherited_fds`; until then `get()` returns `None`.
static INHERITED_FDS: OnceLock<Mutex<InheritedFdTable>> = OnceLock::new();
30
31#[derive(Debug, PartialEq, Error)]
33pub enum InheritedFdError {
34 #[error("init_inherited_fds() not called")]
36 NotInitialized,
37
38 #[error("Ownership of FD {0} is already taken")]
40 OwnershipTaken(RawFd),
41
42 #[error("FD {0} is either invalid file descriptor or not an inherited one")]
44 FileDescriptorNotInherited(RawFd),
45}
46
47pub unsafe fn init_inherited_fds() -> Result<(), std::io::Error> {
57 let mut fds = HashMap::new();
58
59 let fd_path = canonicalize("/proc/self/fd")?;
60
61 for entry in read_dir(&fd_path)? {
62 let entry = entry?;
63
64 let file_name = entry.file_name();
66 let raw_fd = file_name.to_str().unwrap().parse::<RawFd>().unwrap();
67
68 if [libc::STDIN_FILENO, libc::STDOUT_FILENO, libc::STDERR_FILENO].contains(&raw_fd) {
70 continue;
71 }
72
73 if entry.path().read_link()? == fd_path {
77 continue;
78 }
79
80 let owned_fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };
85
86 fcntl(&owned_fd, F_SETFD(FdFlag::FD_CLOEXEC))?;
87 fds.insert(raw_fd, Some(owned_fd));
88 }
89
90 INHERITED_FDS
91 .set(Mutex::new(fds))
92 .or(Err(std::io::Error::other(
93 "Inherited fds were already initialized",
94 )))
95}
96
97pub fn take_fd_ownership(raw_fd: RawFd) -> Result<OwnedFd, InheritedFdError> {
104 let mut fds = INHERITED_FDS
105 .get()
106 .ok_or(InheritedFdError::NotInitialized)?
107 .lock()
108 .unwrap();
109
110 if let Some(value) = fds.get_mut(&raw_fd) {
111 if let Some(owned_fd) = value.take() {
112 Ok(owned_fd)
113 } else {
114 Err(InheritedFdError::OwnershipTaken(raw_fd))
115 }
116 } else {
117 Err(InheritedFdError::FileDescriptorNotInherited(raw_fd))
118 }
119}
120
#[cfg(test)]
mod test {
    use super::*;
    use nix::unistd::close;
    use std::{
        io,
        os::fd::{AsRawFd, IntoRawFd},
    };
    use tempfile::tempfile;

    /// A set of open temp-file descriptors standing in for "inherited" fds.
    ///
    /// The descriptors are kept as bare raw fds (released from their `File` with
    /// `into_raw_fd`) and are closed when the fixture is dropped.
    struct Fixture {
        fds: Vec<RawFd>,
    }

    impl Fixture {
        /// Opens `num_fds` temp files and keeps their raw descriptors.
        fn setup(num_fds: usize) -> Result<Self, io::Error> {
            let mut fds = Vec::new();
            for _ in 0..num_fds {
                fds.push(tempfile()?.into_raw_fd());
            }
            Ok(Fixture { fds })
        }

        /// Opens one more temp file, tracks it for cleanup and returns its raw fd.
        fn open_new_file(&mut self) -> Result<RawFd, io::Error> {
            let raw_fd = tempfile()?.into_raw_fd();
            self.fds.push(raw_fd);
            Ok(raw_fd)
        }
    }

    impl Drop for Fixture {
        fn drop(&mut self) {
            self.fds.iter().for_each(|fd| {
                // Best effort: a test may already have closed the descriptor.
                let _ = close(*fd);
            });
        }
    }

    /// Returns whether `raw_fd` currently refers to an open descriptor.
    fn is_fd_opened(raw_fd: RawFd) -> bool {
        // SAFETY: F_GETFD only queries descriptor flags; it takes no pointers.
        unsafe { libc::fcntl(raw_fd, libc::F_GETFD) != -1 }
    }

    /// Inherited fds can be taken, keep their number, and close when dropped.
    #[test]
    fn happy_case() {
        let fixture = Fixture::setup(2).unwrap();
        let f0 = fixture.fds[0];
        let f1 = fixture.fds[1];

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f0_owned = take_fd_ownership(f0).unwrap();
        let f1_owned = take_fd_ownership(f1).unwrap();
        assert_eq!(f0, f0_owned.as_raw_fd());
        assert_eq!(f1, f1_owned.as_raw_fd());

        drop(f0_owned);
        drop(f1_owned);
        assert!(!is_fd_opened(f0));
        assert!(!is_fd_opened(f1));
    }

    /// A descriptor opened after initialization is not an inherited one.
    #[test]
    fn access_non_inherited_fd() {
        let mut fixture = Fixture::setup(2).unwrap();

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f = fixture.open_new_file().unwrap();
        assert_eq!(
            take_fd_ownership(f).err(),
            Some(InheritedFdError::FileDescriptorNotInherited(f))
        );
    }

    /// A second initialization is rejected.
    #[test]
    fn call_init_inherited_fds_multiple_times() {
        // Bind the fixture to a named variable: `let _ = ...` would drop it on the
        // spot and close its fds before `init_inherited_fds` ever sees them.
        let _fixture = Fixture::setup(2).unwrap();

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // SAFETY: test-only call; this one is expected to fail.
        let res = unsafe { init_inherited_fds() };
        assert!(res.is_err());
    }

    /// Taking a descriptor before initialization reports `NotInitialized`.
    #[test]
    fn access_without_init_inherited_fds() {
        let fixture = Fixture::setup(2).unwrap();

        let f = fixture.fds[0];
        assert_eq!(
            take_fd_ownership(f).err(),
            Some(InheritedFdError::NotInitialized)
        );
    }

    /// The same descriptor cannot be taken twice while the first owner is alive.
    #[test]
    fn double_ownership() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f_owned = take_fd_ownership(f).unwrap();
        let f_double_owned = take_fd_ownership(f);
        assert_eq!(
            f_double_owned.err(),
            Some(InheritedFdError::OwnershipTaken(f)),
        );

        drop(f_owned);
    }

    /// Dropping the owner does not make the descriptor available again.
    #[test]
    fn take_drop_retake() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        let f_owned = take_fd_ownership(f).unwrap();
        drop(f_owned);

        let f_double_owned = take_fd_ownership(f);
        assert_eq!(
            f_double_owned.err(),
            Some(InheritedFdError::OwnershipTaken(f)),
        );
    }

    /// Initialization sets `FD_CLOEXEC` on inherited descriptors.
    #[test]
    fn cloexec() {
        let fixture = Fixture::setup(2).unwrap();
        let f = fixture.fds[0];

        // Clear the descriptor flags first so the assertion below is meaningful.
        // SAFETY: F_SETFD with an integer argument takes no pointers.
        let res = unsafe { libc::fcntl(f.as_raw_fd(), libc::F_SETFD, 0) };
        assert_ne!(res, -1);

        // SAFETY: test-only call; the fixture's fds are bare raw fds with no other owner.
        unsafe {
            init_inherited_fds().unwrap();
        }

        // SAFETY: F_GETFD only queries descriptor flags; it takes no pointers.
        let flags = unsafe { libc::fcntl(f.as_raw_fd(), libc::F_GETFD) };
        assert_ne!(flags, -1);
        assert_eq!(flags, FdFlag::FD_CLOEXEC.bits());
    }
}