Skip to main content

compio_driver/
fd.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    sync::atomic::Ordering,
11    task::Poll,
12};
13
14use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
15
16cfg_if::cfg_if! {
17    if #[cfg(feature = "sync")] {
18        use synchrony::sync;
19    } else {
20        use synchrony::unsync as sync;
21    }
22}
23
24use sync::{atomic::AtomicBool, shared::Shared, waker_slot::WakerSlot};
25
/// State stored behind the `Shared` pointer of a [`SharedFd`]; every clone of
/// the handle refers to the same `Inner`.
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped fd value.
    fd: T,
    /// Whether there is a future waiting. `SharedFd::take` swaps it to `true`
    /// on its first poll, and `SharedFd`'s `Drop` reads it to decide whether
    /// to wake the waiter.
    waits: AtomicBool,
    /// Slot where the pending `SharedFd::take` future registers its waker; it
    /// is woken from `SharedFd`'s `Drop`.
    waker: WakerSlot,
}
33
// Unconditionally mark `Inner<T>` as `RefUnwindSafe`, with no bound on `T`.
// NOTE(review): the justification is not visible in this file (presumably
// the atomic flag / waker slot would otherwise block the auto trait) — confirm.
impl<T> RefUnwindSafe for Inner<T> {}
35
/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
///
/// Cloning a `SharedFd` clones the inner `Shared` pointer, so all clones refer
/// to the same underlying fd value. The owned value can be recovered with
/// `try_unwrap` or `take`.
#[derive(Debug)]
pub struct SharedFd<T>(Shared<Inner<T>>);
40
41impl<T: AsFd> SharedFd<T> {
42    /// Create the shared fd from an owned fd.
43    pub fn new(fd: T) -> Self {
44        unsafe { Self::new_unchecked(fd) }
45    }
46}
47
48impl<T> SharedFd<T> {
49    /// Create the shared fd.
50    ///
51    /// # Safety
52    /// * T should own the fd.
53    pub unsafe fn new_unchecked(fd: T) -> Self {
54        Self(Shared::new(Inner {
55            fd,
56            waits: AtomicBool::new(false),
57            waker: WakerSlot::new(),
58        }))
59    }
60
61    /// Try to take the inner owned fd.
62    pub fn try_unwrap(self) -> Result<T, Self> {
63        let this = ManuallyDrop::new(self);
64        if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
65            Ok(fd)
66        } else {
67            Err(ManuallyDrop::into_inner(this))
68        }
69    }
70
71    // SAFETY: if `Some` is returned, the method should not be called again.
72    unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
73        // SAFETY: `this` is not dropped here.
74        let ptr = unsafe { std::ptr::read(&this.0) };
75        // The ptr is duplicated without increasing the strong count, should forget.
76        match Shared::try_unwrap(ptr) {
77            Ok(inner) => Some(inner.fd),
78            Err(ptr) => {
79                std::mem::forget(ptr);
80                None
81            }
82        }
83    }
84
85    /// Wait and take the inner owned fd.
86    pub fn take(self) -> impl Future<Output = Option<T>> {
87        let this = ManuallyDrop::new(self);
88        async move {
89            if !this.0.waits.swap(true, Ordering::AcqRel) {
90                poll_fn(move |cx| {
91                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
92                        return Poll::Ready(Some(fd));
93                    }
94
95                    this.0.waker.register(cx.waker());
96
97                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
98                        Poll::Ready(Some(fd))
99                    } else {
100                        Poll::Pending
101                    }
102                })
103                .await
104            } else {
105                None
106            }
107        }
108    }
109}
110
111impl<T> Drop for SharedFd<T> {
112    fn drop(&mut self) {
113        // It's OK to wake multiple times.
114        if Shared::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
115            self.0.waker.wake()
116        }
117    }
118}
119
120impl<T: AsFd> AsFd for SharedFd<T> {
121    fn as_fd(&self) -> BorrowedFd<'_> {
122        self.0.fd.as_fd()
123    }
124}
125
126impl<T: AsFd> AsRawFd for SharedFd<T> {
127    fn as_raw_fd(&self) -> RawFd {
128        self.as_fd().as_raw_fd()
129    }
130}
131
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Build a `T` from the raw handle and wrap it in a new shared handle.
    ///
    /// # Safety
    /// The requirements of `T::from_raw_handle` apply to `handle`.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: forwarded to the caller of this function.
        let owned = unsafe { T::from_raw_handle(handle) };
        // SAFETY: `owned` was just created from the raw handle.
        unsafe { Self::new_unchecked(owned) }
    }
}
138
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Build a `T` from the raw socket and wrap it in a new shared handle.
    ///
    /// # Safety
    /// The requirements of `T::from_raw_socket` apply to `sock`.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: forwarded to the caller of this function.
        let owned = unsafe { T::from_raw_socket(sock) };
        // SAFETY: `owned` was just created from the raw socket.
        unsafe { Self::new_unchecked(owned) }
    }
}
145
146#[cfg(unix)]
147impl<T: FromRawFd> FromRawFd for SharedFd<T> {
148    unsafe fn from_raw_fd(fd: RawFd) -> Self {
149        unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
150    }
151}
152
153impl<T> Clone for SharedFd<T> {
154    fn clone(&self) -> Self {
155        Self(self.0.clone())
156    }
157}
158
159impl<T> Deref for SharedFd<T> {
160    type Target = T;
161
162    fn deref(&self) -> &Self::Target {
163        &self.0.fd
164    }
165}
166
/// Get a clone of [`SharedFd`].
///
/// `T` is the type wrapped by the returned [`SharedFd`].
pub trait ToSharedFd<T> {
    /// Return a cloned [`SharedFd`].
    fn to_shared_fd(&self) -> SharedFd<T>;
}
172
173impl<T> ToSharedFd<T> for SharedFd<T> {
174    fn to_shared_fd(&self) -> SharedFd<T> {
175        self.clone()
176    }
177}