compio_driver/
fd.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    sync::{
11        Arc,
12        atomic::{AtomicBool, Ordering},
13    },
14    task::Poll,
15};
16
17use futures_util::task::AtomicWaker;
18
19use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
20
21#[derive(Debug)]
22struct Inner<T> {
23    fd: T,
24    // whether there is a future waiting
25    waits: AtomicBool,
26    waker: AtomicWaker,
27}
28
29impl<T> RefUnwindSafe for Inner<T> {}
30
31/// A shared fd. It is passed to the operations to make sure the fd won't be
32/// closed before the operations complete.
33#[derive(Debug)]
34pub struct SharedFd<T>(Arc<Inner<T>>);
35
36impl<T: AsFd> SharedFd<T> {
37    /// Create the shared fd from an owned fd.
38    pub fn new(fd: T) -> Self {
39        unsafe { Self::new_unchecked(fd) }
40    }
41}
42
43impl<T> SharedFd<T> {
44    /// Create the shared fd.
45    ///
46    /// # Safety
47    /// * T should own the fd.
48    pub unsafe fn new_unchecked(fd: T) -> Self {
49        Self(Arc::new(Inner {
50            fd,
51            waits: AtomicBool::new(false),
52            waker: AtomicWaker::new(),
53        }))
54    }
55
56    /// Try to take the inner owned fd.
57    pub fn try_unwrap(self) -> Result<T, Self> {
58        let this = ManuallyDrop::new(self);
59        if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
60            Ok(fd)
61        } else {
62            Err(ManuallyDrop::into_inner(this))
63        }
64    }
65
66    // SAFETY: if `Some` is returned, the method should not be called again.
67    unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
68        let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
69        // The ptr is duplicated without increasing the strong count, should forget.
70        match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
71            Ok(inner) => Some(inner.fd),
72            Err(ptr) => {
73                std::mem::forget(ptr);
74                None
75            }
76        }
77    }
78
79    /// Wait and take the inner owned fd.
80    pub fn take(self) -> impl Future<Output = Option<T>> {
81        let this = ManuallyDrop::new(self);
82        async move {
83            if !this.0.waits.swap(true, Ordering::AcqRel) {
84                poll_fn(move |cx| {
85                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
86                        return Poll::Ready(Some(fd));
87                    }
88
89                    this.0.waker.register(cx.waker());
90
91                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
92                        Poll::Ready(Some(fd))
93                    } else {
94                        Poll::Pending
95                    }
96                })
97                .await
98            } else {
99                None
100            }
101        }
102    }
103}
104
105impl<T> Drop for SharedFd<T> {
106    fn drop(&mut self) {
107        // It's OK to wake multiple times.
108        if Arc::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
109            self.0.waker.wake()
110        }
111    }
112}
113
114impl<T: AsFd> AsFd for SharedFd<T> {
115    fn as_fd(&self) -> BorrowedFd<'_> {
116        self.0.fd.as_fd()
117    }
118}
119
120impl<T: AsFd> AsRawFd for SharedFd<T> {
121    fn as_raw_fd(&self) -> RawFd {
122        self.as_fd().as_raw_fd()
123    }
124}
125
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Wrap a raw handle by first converting it into an owned `T`.
    ///
    /// # Safety
    /// The caller must uphold `T::from_raw_handle`'s contract for `handle`.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `T::from_raw_handle`'s contract.
        // `FromRawHandle` takes ownership of the handle, so the resulting `T`
        // owns it, as `new_unchecked` requires.
        unsafe { Self::new_unchecked(T::from_raw_handle(handle)) }
    }
}
132
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Wrap a raw socket by first converting it into an owned `T`.
    ///
    /// # Safety
    /// The caller must uphold `T::from_raw_socket`'s contract for `sock`.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `T::from_raw_socket`'s contract.
        // `FromRawSocket` takes ownership of the socket, so the resulting `T`
        // owns it, as `new_unchecked` requires.
        unsafe { Self::new_unchecked(T::from_raw_socket(sock)) }
    }
}
139
140#[cfg(unix)]
141impl<T: FromRawFd> FromRawFd for SharedFd<T> {
142    unsafe fn from_raw_fd(fd: RawFd) -> Self {
143        Self::new_unchecked(T::from_raw_fd(fd))
144    }
145}
146
147impl<T> Clone for SharedFd<T> {
148    fn clone(&self) -> Self {
149        Self(self.0.clone())
150    }
151}
152
153impl<T> Deref for SharedFd<T> {
154    type Target = T;
155
156    fn deref(&self) -> &Self::Target {
157        &self.0.fd
158    }
159}
160
161/// Get a clone of [`SharedFd`].
162pub trait ToSharedFd<T> {
163    /// Return a cloned [`SharedFd`].
164    fn to_shared_fd(&self) -> SharedFd<T>;
165}
166
167impl<T> ToSharedFd<T> for SharedFd<T> {
168    fn to_shared_fd(&self) -> SharedFd<T> {
169        self.clone()
170    }
171}