1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6 future::{Future, poll_fn},
7 mem::ManuallyDrop,
8 ops::Deref,
9 panic::RefUnwindSafe,
10 sync::atomic::Ordering,
11 task::Poll,
12};
13
14use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
15
16cfg_if::cfg_if! {
17 if #[cfg(feature = "fd-sync")] {
18 #[path = "sync.rs"]
19 mod types;
20 } else {
21 #[path = "unsync.rs"]
22 mod types;
23 }
24}
25
26use types::{RefPtr, WaitFlag, WakerRegistry};
27
/// Shared state behind every clone of a `SharedFd`.
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped fd value; moved out by `try_unwrap`/`take` once the
    /// allocation has a single remaining reference.
    fd: T,
    /// Whether a `take` future has claimed the right to wait for the last
    /// reference (set with `swap(true, ..)` in `take`; only the caller that
    /// observes the previous value `false` proceeds to wait).
    waits: WaitFlag,
    /// Waker of the waiting `take` future; `Drop for SharedFd` wakes it when a
    /// drop leaves exactly one other reference alive.
    waker: WakerRegistry,
}
35
// Unconditionally marks `Inner<T>` as `RefUnwindSafe`, for every `T` and even
// though `waits`/`waker` are mutated through shared references (see `take` and
// `Drop for SharedFd`).
// NOTE(review): this overrides the compiler's auto-trait analysis; presumably it
// exists so `SharedFd<T>` can be used across `catch_unwind` boundaries — confirm
// that a panic can never leave these fields in a state later observers misread.
impl<T> RefUnwindSafe for Inner<T> {}
37
/// A cloneable, reference-counted handle to an fd value `T`.
///
/// All clones share a single allocation. The value can be moved back out with
/// [`SharedFd::try_unwrap`] or [`SharedFd::take`] once only one handle remains.
#[derive(Debug)]
pub struct SharedFd<T>(RefPtr<Inner<T>>);
42
43impl<T: AsFd> SharedFd<T> {
44 pub fn new(fd: T) -> Self {
46 unsafe { Self::new_unchecked(fd) }
47 }
48}
49
50impl<T> SharedFd<T> {
51 pub unsafe fn new_unchecked(fd: T) -> Self {
56 Self(RefPtr::new(Inner {
57 fd,
58 waits: WaitFlag::new(false),
59 waker: WakerRegistry::new(),
60 }))
61 }
62
63 pub fn try_unwrap(self) -> Result<T, Self> {
65 let this = ManuallyDrop::new(self);
66 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
67 Ok(fd)
68 } else {
69 Err(ManuallyDrop::into_inner(this))
70 }
71 }
72
73 unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
75 let ptr = unsafe { std::ptr::read(&this.0) };
77 match RefPtr::try_unwrap(ptr) {
79 Ok(inner) => Some(inner.fd),
80 Err(ptr) => {
81 std::mem::forget(ptr);
82 None
83 }
84 }
85 }
86
87 pub fn take(self) -> impl Future<Output = Option<T>> {
89 let this = ManuallyDrop::new(self);
90 async move {
91 if !this.0.waits.swap(true, Ordering::AcqRel) {
92 poll_fn(move |cx| {
93 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
94 return Poll::Ready(Some(fd));
95 }
96
97 this.0.waker.register(cx.waker());
98
99 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
100 Poll::Ready(Some(fd))
101 } else {
102 Poll::Pending
103 }
104 })
105 .await
106 } else {
107 None
108 }
109 }
110 }
111}
112
113impl<T> Drop for SharedFd<T> {
114 fn drop(&mut self) {
115 if RefPtr::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
117 self.0.waker.wake()
118 }
119 }
120}
121
122impl<T: AsFd> AsFd for SharedFd<T> {
123 fn as_fd(&self) -> BorrowedFd<'_> {
124 self.0.fd.as_fd()
125 }
126}
127
128impl<T: AsFd> AsRawFd for SharedFd<T> {
129 fn as_raw_fd(&self) -> RawFd {
130 self.as_fd().as_raw_fd()
131 }
132}
133
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds `T` from `handle` and wraps it in a new shared handle.
    ///
    /// # Safety
    ///
    /// Same contract as `T::from_raw_handle`; the resulting value is then passed
    /// to [`SharedFd::new_unchecked`].
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `T::from_raw_handle`'s contract.
        let owned = unsafe { T::from_raw_handle(handle) };
        // SAFETY: `owned` was just constructed from the raw handle.
        // NOTE(review): confirm this satisfies `new_unchecked`'s contract for `T`.
        unsafe { SharedFd::new_unchecked(owned) }
    }
}
140
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds `T` from `sock` and wraps it in a new shared handle.
    ///
    /// # Safety
    ///
    /// Same contract as `T::from_raw_socket`; the resulting value is then passed
    /// to [`SharedFd::new_unchecked`].
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `T::from_raw_socket`'s contract.
        let owned = unsafe { T::from_raw_socket(sock) };
        // SAFETY: `owned` was just constructed from the raw socket.
        // NOTE(review): confirm this satisfies `new_unchecked`'s contract for `T`.
        unsafe { SharedFd::new_unchecked(owned) }
    }
}
147
148#[cfg(unix)]
149impl<T: FromRawFd> FromRawFd for SharedFd<T> {
150 unsafe fn from_raw_fd(fd: RawFd) -> Self {
151 unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
152 }
153}
154
155impl<T> Clone for SharedFd<T> {
156 fn clone(&self) -> Self {
157 Self(self.0.clone())
158 }
159}
160
161impl<T> Deref for SharedFd<T> {
162 type Target = T;
163
164 fn deref(&self) -> &Self::Target {
165 &self.0.fd
166 }
167}
168
/// Obtains a [`SharedFd<T>`] handle from a borrowed value.
pub trait ToSharedFd<T> {
    /// Returns a [`SharedFd<T>`] handle; for `SharedFd<T>` itself this is a
    /// clone of the handle.
    fn to_shared_fd(&self) -> SharedFd<T>;
}
174
175impl<T> ToSharedFd<T> for SharedFd<T> {
176 fn to_shared_fd(&self) -> SharedFd<T> {
177 self.clone()
178 }
179}