1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6 future::{Future, poll_fn},
7 mem::ManuallyDrop,
8 ops::Deref,
9 panic::RefUnwindSafe,
10 sync::{
11 Arc,
12 atomic::{AtomicBool, Ordering},
13 },
14 task::Poll,
15};
16
17use futures_util::task::AtomicWaker;
18
19use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
20
/// Shared state behind every clone of a [`SharedFd`].
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped fd/handle-like value.
    fd: T,
    /// Set to `true` by the first `SharedFd::take` that starts waiting; any
    /// later `take` that finds it already set resolves to `None`.
    waits: AtomicBool,
    /// Waker of the task waiting in `SharedFd::take`; it is woken from
    /// `Drop for SharedFd` when a handle is dropped while the strong count is 2.
    waker: AtomicWaker,
}
28
// Unconditionally marks `Inner<T>` (and therefore `SharedFd<T>`, via `Arc`) as
// unwind-safe for every `T`.
// NOTE(review): presumably this exists because `AtomicWaker` uses interior
// mutability and would otherwise make `SharedFd<T>` non-unwind-safe. Because
// there is no `T: RefUnwindSafe` bound, the impl also covers `T`s that are not
// themselves unwind-safe — confirm that this is intended.
impl<T> RefUnwindSafe for Inner<T> {}
30
/// A reference-counted, cloneable wrapper around an fd/handle-like value `T`.
///
/// All clones share the same `T`. [`SharedFd::try_unwrap`] and
/// [`SharedFd::take`] move the value back out once only one handle is left.
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);
35
36impl<T: AsFd> SharedFd<T> {
37 pub fn new(fd: T) -> Self {
39 unsafe { Self::new_unchecked(fd) }
40 }
41}
42
43impl<T> SharedFd<T> {
44 pub unsafe fn new_unchecked(fd: T) -> Self {
49 Self(Arc::new(Inner {
50 fd,
51 waits: AtomicBool::new(false),
52 waker: AtomicWaker::new(),
53 }))
54 }
55
56 pub fn try_unwrap(self) -> Result<T, Self> {
58 let this = ManuallyDrop::new(self);
59 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
60 Ok(fd)
61 } else {
62 Err(ManuallyDrop::into_inner(this))
63 }
64 }
65
66 unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
68 let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
69 match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
71 Ok(inner) => Some(inner.fd),
72 Err(ptr) => {
73 std::mem::forget(ptr);
74 None
75 }
76 }
77 }
78
79 pub fn take(self) -> impl Future<Output = Option<T>> {
81 let this = ManuallyDrop::new(self);
82 async move {
83 if !this.0.waits.swap(true, Ordering::AcqRel) {
84 poll_fn(move |cx| {
85 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
86 return Poll::Ready(Some(fd));
87 }
88
89 this.0.waker.register(cx.waker());
90
91 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
92 Poll::Ready(Some(fd))
93 } else {
94 Poll::Pending
95 }
96 })
97 .await
98 } else {
99 None
100 }
101 }
102 }
103}
104
105impl<T> Drop for SharedFd<T> {
106 fn drop(&mut self) {
107 if Arc::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
109 self.0.waker.wake()
110 }
111 }
112}
113
114impl<T: AsFd> AsFd for SharedFd<T> {
115 fn as_fd(&self) -> BorrowedFd<'_> {
116 self.0.fd.as_fd()
117 }
118}
119
120impl<T: AsFd> AsRawFd for SharedFd<T> {
121 fn as_raw_fd(&self) -> RawFd {
122 self.as_fd().as_raw_fd()
123 }
124}
125
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds `T` from a raw handle and wraps it in a new shared handle.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `FromRawHandle::from_raw_handle`'s
        // contract for `handle`; that same contract is forwarded to
        // `new_unchecked` for the resulting owner.
        unsafe {
            let owner = T::from_raw_handle(handle);
            Self::new_unchecked(owner)
        }
    }
}
132
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds `T` from a raw socket and wraps it in a new shared handle.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `FromRawSocket::from_raw_socket`'s
        // contract for `sock`; that same contract is forwarded to
        // `new_unchecked` for the resulting owner.
        unsafe {
            let owner = T::from_raw_socket(sock);
            Self::new_unchecked(owner)
        }
    }
}
139
140#[cfg(unix)]
141impl<T: FromRawFd> FromRawFd for SharedFd<T> {
142 unsafe fn from_raw_fd(fd: RawFd) -> Self {
143 Self::new_unchecked(T::from_raw_fd(fd))
144 }
145}
146
147impl<T> Clone for SharedFd<T> {
148 fn clone(&self) -> Self {
149 Self(self.0.clone())
150 }
151}
152
153impl<T> Deref for SharedFd<T> {
154 type Target = T;
155
156 fn deref(&self) -> &Self::Target {
157 &self.0.fd
158 }
159}
160
/// Types that can hand out a [`SharedFd<T>`] handle.
pub trait ToSharedFd<T> {
    /// Returns a `SharedFd<T>` obtained from `&self`; for `SharedFd` itself
    /// this is a clone that shares the same underlying value.
    fn to_shared_fd(&self) -> SharedFd<T>;
}
166
167impl<T> ToSharedFd<T> for SharedFd<T> {
168 fn to_shared_fd(&self) -> SharedFd<T> {
169 self.clone()
170 }
171}