compio_driver/fd/mod.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    sync::atomic::Ordering,
11    task::Poll,
12};
13
14use crate::{AsFd, AsRawFd, BorrowedFd, RawFd};
15
// Pick the implementation of the shared-state primitives (`RefPtr`, `WaitFlag`,
// `WakerRegistry`) at compile time: `sync.rs` when the `fd-sync` feature is enabled,
// `unsync.rs` otherwise. Both files must expose the same names, since the `use types::…`
// below is unconditional.
// NOTE(review): presumably `sync.rs` is the thread-safe (atomic/`Arc`) variant and
// `unsync.rs` the single-threaded one — confirm against those files.
cfg_if::cfg_if! {
    if #[cfg(feature = "fd-sync")] {
        #[path = "sync.rs"]
        mod types;
    } else {
        #[path = "unsync.rs"]
        mod types;
    }
}
25
26use types::{RefPtr, WaitFlag, WakerRegistry};
27
/// State shared by every clone of a [`SharedFd`].
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped fd-owning value.
    fd: T,
    /// Whether a `SharedFd::take` future has claimed the waiter slot. Only one
    /// `take` may wait at a time; nothing in this module ever clears the flag.
    waits: WaitFlag,
    /// Waker registered by the waiting `take`; woken from `Drop for SharedFd`.
    waker: WakerRegistry,
}
35
// Declare `Inner<T>` unwind-safe for every `T`, whether or not `T` itself is.
// NOTE(review): presumably needed because `WaitFlag`/`WakerRegistry` use interior
// mutability that would otherwise opt `Inner` (and so `SharedFd`) out of
// `RefUnwindSafe` — confirm, and decide whether a `T: RefUnwindSafe` bound is intended.
impl<T> RefUnwindSafe for Inner<T> {}
37
/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
///
/// All clones point at the same reference-counted state. The wrapped value is dropped
/// when the last clone goes away, unless it is moved out first by
/// [`SharedFd::try_unwrap`] or [`SharedFd::take`].
#[derive(Debug)]
pub struct SharedFd<T>(RefPtr<Inner<T>>);
42
43impl<T: AsFd> SharedFd<T> {
44    /// Create the shared fd from an owned fd.
45    pub fn new(fd: T) -> Self {
46        unsafe { Self::new_unchecked(fd) }
47    }
48}
49
50impl<T> SharedFd<T> {
51    /// Create the shared fd.
52    ///
53    /// # Safety
54    /// * T should own the fd.
55    pub unsafe fn new_unchecked(fd: T) -> Self {
56        Self(RefPtr::new(Inner {
57            fd,
58            waits: WaitFlag::new(false),
59            waker: WakerRegistry::new(),
60        }))
61    }
62
63    /// Try to take the inner owned fd.
64    pub fn try_unwrap(self) -> Result<T, Self> {
65        let this = ManuallyDrop::new(self);
66        if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
67            Ok(fd)
68        } else {
69            Err(ManuallyDrop::into_inner(this))
70        }
71    }
72
73    // SAFETY: if `Some` is returned, the method should not be called again.
74    unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
75        // SAFETY: `this` is not dropped here.
76        let ptr = unsafe { std::ptr::read(&this.0) };
77        // The ptr is duplicated without increasing the strong count, should forget.
78        match RefPtr::try_unwrap(ptr) {
79            Ok(inner) => Some(inner.fd),
80            Err(ptr) => {
81                std::mem::forget(ptr);
82                None
83            }
84        }
85    }
86
87    /// Wait and take the inner owned fd.
88    pub fn take(self) -> impl Future<Output = Option<T>> {
89        let this = ManuallyDrop::new(self);
90        async move {
91            if !this.0.waits.swap(true, Ordering::AcqRel) {
92                poll_fn(move |cx| {
93                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
94                        return Poll::Ready(Some(fd));
95                    }
96
97                    this.0.waker.register(cx.waker());
98
99                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
100                        Poll::Ready(Some(fd))
101                    } else {
102                        Poll::Pending
103                    }
104                })
105                .await
106            } else {
107                None
108            }
109        }
110    }
111}
112
113impl<T> Drop for SharedFd<T> {
114    fn drop(&mut self) {
115        // It's OK to wake multiple times.
116        if RefPtr::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
117            self.0.waker.wake()
118        }
119    }
120}
121
122impl<T: AsFd> AsFd for SharedFd<T> {
123    fn as_fd(&self) -> BorrowedFd<'_> {
124        self.0.fd.as_fd()
125    }
126}
127
128impl<T: AsFd> AsRawFd for SharedFd<T> {
129    fn as_raw_fd(&self) -> RawFd {
130        self.as_fd().as_raw_fd()
131    }
132}
133
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Build `T` from a raw handle and wrap it in a new [`SharedFd`].
    ///
    /// # Safety
    /// Same contract as `T::from_raw_handle`: `handle` must be an owned handle whose
    /// ownership passes to the returned value.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller upholds `T::from_raw_handle`'s contract for `handle`.
        let owned = unsafe { T::from_raw_handle(handle) };
        // SAFETY: `owned` took ownership of `handle` above.
        unsafe { Self::new_unchecked(owned) }
    }
}
140
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Build `T` from a raw socket and wrap it in a new [`SharedFd`].
    ///
    /// # Safety
    /// Same contract as `T::from_raw_socket`: `sock` must be an owned socket whose
    /// ownership passes to the returned value.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller upholds `T::from_raw_socket`'s contract for `sock`.
        let owned = unsafe { T::from_raw_socket(sock) };
        // SAFETY: `owned` took ownership of `sock` above.
        unsafe { Self::new_unchecked(owned) }
    }
}
147
148#[cfg(unix)]
149impl<T: FromRawFd> FromRawFd for SharedFd<T> {
150    unsafe fn from_raw_fd(fd: RawFd) -> Self {
151        unsafe { Self::new_unchecked(T::from_raw_fd(fd)) }
152    }
153}
154
155impl<T> Clone for SharedFd<T> {
156    fn clone(&self) -> Self {
157        Self(self.0.clone())
158    }
159}
160
161impl<T> Deref for SharedFd<T> {
162    type Target = T;
163
164    fn deref(&self) -> &Self::Target {
165        &self.0.fd
166    }
167}
168
/// Get a clone of [`SharedFd`].
///
/// Lets an operation obtain its own [`SharedFd`] handle from `&self`, keeping the fd
/// alive independently of the caller. The impl in this module is the one for
/// `SharedFd` itself, which simply clones.
pub trait ToSharedFd<T> {
    /// Return a cloned [`SharedFd`].
    fn to_shared_fd(&self) -> SharedFd<T>;
}
174
175impl<T> ToSharedFd<T> for SharedFd<T> {
176    fn to_shared_fd(&self) -> SharedFd<T> {
177        self.clone()
178    }
179}