1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6 future::{Future, poll_fn},
7 mem::ManuallyDrop,
8 ops::Deref,
9 panic::RefUnwindSafe,
10 sync::atomic::Ordering,
11 task::Poll,
12};
13
14use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
15
16cfg_if::cfg_if! {
17 if #[cfg(feature = "sync")] {
18 use synchrony::sync;
19 } else {
20 use synchrony::unsync as sync;
21 }
22}
23
24use sync::{atomic::AtomicBool, shared::Shared, waker_slot::WakerSlot};
25
/// Shared state behind every clone of a [`SharedFd`].
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped fd-like value.
    fd: T,
    /// Set to `true` by the first [`SharedFd::take`] call; later `take` calls see
    /// `true` and resolve to `None`. Also read by `Drop for SharedFd`.
    waits: AtomicBool,
    /// Waker registered by the pending `take` future; woken from `Drop for SharedFd`.
    waker: WakerSlot,
}
33
// NOTE(review): this asserts `RefUnwindSafe` for every `T`, with no bound on `T`.
// Presumably it exists because the `sync`/`unsync` atomic or waker-slot types (or `T`)
// would otherwise block the auto impl — confirm that observing an `Inner<T>` after a
// caught panic is sound for every `T` this is used with.
impl<T> RefUnwindSafe for Inner<T> {}
35
/// A cloneable, shared handle to an fd-like value `T`.
///
/// All clones point at the same `Inner`. The value can be reclaimed with
/// [`SharedFd::try_unwrap`] or awaited with [`SharedFd::take`]. Shared access to `T`
/// is available through `Deref`.
#[derive(Debug)]
pub struct SharedFd<T>(Shared<Inner<T>>);
40
41impl<T: AsFd> SharedFd<T> {
42 pub fn new(fd: T) -> Self {
44 unsafe { Self::new_unchecked(fd) }
45 }
46}
47
48impl<T> SharedFd<T> {
49 pub unsafe fn new_unchecked(fd: T) -> Self {
54 Self(Shared::new(Inner {
55 fd,
56 waits: AtomicBool::new(false),
57 waker: WakerSlot::new(),
58 }))
59 }
60
61 pub fn try_unwrap(self) -> Result<T, Self> {
63 let this = ManuallyDrop::new(self);
64 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
65 Ok(fd)
66 } else {
67 Err(ManuallyDrop::into_inner(this))
68 }
69 }
70
71 unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
73 let ptr = unsafe { std::ptr::read(&this.0) };
75 match Shared::try_unwrap(ptr) {
77 Ok(inner) => Some(inner.fd),
78 Err(ptr) => {
79 std::mem::forget(ptr);
80 None
81 }
82 }
83 }
84
85 pub fn take(self) -> impl Future<Output = Option<T>> {
87 let this = ManuallyDrop::new(self);
88 async move {
89 if !this.0.waits.swap(true, Ordering::AcqRel) {
90 poll_fn(move |cx| {
91 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
92 return Poll::Ready(Some(fd));
93 }
94
95 this.0.waker.register(cx.waker());
96
97 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
98 Poll::Ready(Some(fd))
99 } else {
100 Poll::Pending
101 }
102 })
103 .await
104 } else {
105 None
106 }
107 }
108 }
109}
110
111impl<T> Drop for SharedFd<T> {
112 fn drop(&mut self) {
113 if Shared::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
115 self.0.waker.wake()
116 }
117 }
118}
119
120impl<T: AsFd> AsFd for SharedFd<T> {
121 fn as_fd(&self) -> BorrowedFd<'_> {
122 self.0.fd.as_fd()
123 }
124}
125
126impl<T: AsFd> AsRawFd for SharedFd<T> {
127 fn as_raw_fd(&self) -> RawFd {
128 self.as_fd().as_raw_fd()
129 }
130}
131
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds a `T` from `handle` and wraps it in a new shared handle.
    ///
    /// # Safety
    ///
    /// `handle` is passed straight to `T::from_raw_handle`, so the caller must uphold
    /// that function's safety contract.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `T::from_raw_handle`'s contract for `handle`.
        let owned = unsafe { T::from_raw_handle(handle) };
        // SAFETY: `owned` was built from a handle the caller vouched for.
        // NOTE(review): confirm this satisfies `new_unchecked`'s contract.
        unsafe { Self::new_unchecked(owned) }
    }
}
138
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds a `T` from `sock` and wraps it in a new shared handle.
    ///
    /// # Safety
    ///
    /// `sock` is passed straight to `T::from_raw_socket`, so the caller must uphold
    /// that function's safety contract.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `T::from_raw_socket`'s contract for `sock`.
        let owned = unsafe { T::from_raw_socket(sock) };
        // SAFETY: `owned` was built from a socket the caller vouched for.
        // NOTE(review): confirm this satisfies `new_unchecked`'s contract.
        unsafe { Self::new_unchecked(owned) }
    }
}
145
146#[cfg(unix)]
147impl<T: FromRawFd> FromRawFd for SharedFd<T> {
148 unsafe fn from_raw_fd(fd: RawFd) -> Self {
149 unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
150 }
151}
152
153impl<T> Clone for SharedFd<T> {
154 fn clone(&self) -> Self {
155 Self(self.0.clone())
156 }
157}
158
159impl<T> Deref for SharedFd<T> {
160 type Target = T;
161
162 fn deref(&self) -> &Self::Target {
163 &self.0.fd
164 }
165}
166
/// Types that can hand out a [`SharedFd`] handle.
pub trait ToSharedFd<T> {
    /// Returns a [`SharedFd`] for the fd-like value behind `self`.
    fn to_shared_fd(&self) -> SharedFd<T>;
}
172
173impl<T> ToSharedFd<T> for SharedFd<T> {
174 fn to_shared_fd(&self) -> SharedFd<T> {
175 self.clone()
176 }
177}