compio_driver/iocp/
mod.rs

1use std::{
2    collections::HashMap,
3    io,
4    os::windows::io::{
5        AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6        OwnedSocket,
7    },
8    pin::Pin,
9    sync::Arc,
10    task::{Poll, Wake, Waker},
11    time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{
16    Foundation::{ERROR_CANCELLED, HANDLE},
17    System::IO::OVERLAPPED,
18};
19
20use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder};
21
22pub(crate) mod op;
23
24mod cp;
25mod wait;
26
27/// On windows, handle and socket are in the same size.
28/// Both of them could be attached to an IOCP.
29/// Therefore, both could be seen as fd.
30pub type RawFd = HANDLE;
31
32/// Extracts raw fds.
33pub trait AsRawFd {
34    /// Extracts the raw fd.
35    fn as_raw_fd(&self) -> RawFd;
36}
37
38/// Owned handle or socket on Windows.
39#[derive(Debug)]
40pub enum OwnedFd {
41    /// Win32 handle.
42    File(OwnedHandle),
43    /// Windows socket handle.
44    Socket(OwnedSocket),
45}
46
47impl AsRawFd for OwnedFd {
48    fn as_raw_fd(&self) -> RawFd {
49        match self {
50            Self::File(fd) => fd.as_raw_handle() as _,
51            Self::Socket(s) => s.as_raw_socket() as _,
52        }
53    }
54}
55
56impl AsRawFd for RawFd {
57    fn as_raw_fd(&self) -> RawFd {
58        *self
59    }
60}
61
62impl AsRawFd for std::fs::File {
63    fn as_raw_fd(&self) -> RawFd {
64        self.as_raw_handle() as _
65    }
66}
67
68impl AsRawFd for OwnedHandle {
69    fn as_raw_fd(&self) -> RawFd {
70        self.as_raw_handle() as _
71    }
72}
73
74impl AsRawFd for socket2::Socket {
75    fn as_raw_fd(&self) -> RawFd {
76        self.as_raw_socket() as _
77    }
78}
79
80impl AsRawFd for OwnedSocket {
81    fn as_raw_fd(&self) -> RawFd {
82        self.as_raw_socket() as _
83    }
84}
85
86impl AsRawFd for std::process::ChildStdin {
87    fn as_raw_fd(&self) -> RawFd {
88        self.as_raw_handle() as _
89    }
90}
91
92impl AsRawFd for std::process::ChildStdout {
93    fn as_raw_fd(&self) -> RawFd {
94        self.as_raw_handle() as _
95    }
96}
97
98impl AsRawFd for std::process::ChildStderr {
99    fn as_raw_fd(&self) -> RawFd {
100        self.as_raw_handle() as _
101    }
102}
103
104impl From<OwnedHandle> for OwnedFd {
105    fn from(value: OwnedHandle) -> Self {
106        Self::File(value)
107    }
108}
109
110impl From<std::fs::File> for OwnedFd {
111    fn from(value: std::fs::File) -> Self {
112        Self::File(OwnedHandle::from(value))
113    }
114}
115
116impl From<std::process::ChildStdin> for OwnedFd {
117    fn from(value: std::process::ChildStdin) -> Self {
118        Self::File(OwnedHandle::from(value))
119    }
120}
121
122impl From<std::process::ChildStdout> for OwnedFd {
123    fn from(value: std::process::ChildStdout) -> Self {
124        Self::File(OwnedHandle::from(value))
125    }
126}
127
128impl From<std::process::ChildStderr> for OwnedFd {
129    fn from(value: std::process::ChildStderr) -> Self {
130        Self::File(OwnedHandle::from(value))
131    }
132}
133
134impl From<OwnedSocket> for OwnedFd {
135    fn from(value: OwnedSocket) -> Self {
136        Self::Socket(value)
137    }
138}
139
140impl From<socket2::Socket> for OwnedFd {
141    fn from(value: socket2::Socket) -> Self {
142        Self::Socket(OwnedSocket::from(value))
143    }
144}
145
146/// Borrowed handle or socket on Windows.
147#[derive(Debug)]
148pub enum BorrowedFd<'a> {
149    /// Win32 handle.
150    File(BorrowedHandle<'a>),
151    /// Windows socket handle.
152    Socket(BorrowedSocket<'a>),
153}
154
155impl AsRawFd for BorrowedFd<'_> {
156    fn as_raw_fd(&self) -> RawFd {
157        match self {
158            Self::File(fd) => fd.as_raw_handle() as RawFd,
159            Self::Socket(s) => s.as_raw_socket() as RawFd,
160        }
161    }
162}
163
164impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
165    fn from(value: BorrowedHandle<'a>) -> Self {
166        Self::File(value)
167    }
168}
169
170impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
171    fn from(value: BorrowedSocket<'a>) -> Self {
172        Self::Socket(value)
173    }
174}
175
176/// Extracts fds.
177pub trait AsFd {
178    /// Extracts the borrowed fd.
179    fn as_fd(&self) -> BorrowedFd<'_>;
180}
181
182impl AsFd for OwnedFd {
183    fn as_fd(&self) -> BorrowedFd<'_> {
184        match self {
185            Self::File(fd) => fd.as_fd(),
186            Self::Socket(s) => s.as_fd(),
187        }
188    }
189}
190
191impl AsFd for std::fs::File {
192    fn as_fd(&self) -> BorrowedFd<'_> {
193        self.as_handle().into()
194    }
195}
196
197impl AsFd for OwnedHandle {
198    fn as_fd(&self) -> BorrowedFd<'_> {
199        self.as_handle().into()
200    }
201}
202
203impl AsFd for socket2::Socket {
204    fn as_fd(&self) -> BorrowedFd<'_> {
205        self.as_socket().into()
206    }
207}
208
209impl AsFd for OwnedSocket {
210    fn as_fd(&self) -> BorrowedFd<'_> {
211        self.as_socket().into()
212    }
213}
214
215impl AsFd for std::process::ChildStdin {
216    fn as_fd(&self) -> BorrowedFd<'_> {
217        self.as_handle().into()
218    }
219}
220
221impl AsFd for std::process::ChildStdout {
222    fn as_fd(&self) -> BorrowedFd<'_> {
223        self.as_handle().into()
224    }
225}
226
227impl AsFd for std::process::ChildStderr {
228    fn as_fd(&self) -> BorrowedFd<'_> {
229        self.as_handle().into()
230    }
231}
232
233/// Operation type.
234pub enum OpType {
235    /// An overlapped operation.
236    Overlapped,
237    /// A blocking operation, needs a thread to spawn. The `operate` method
238    /// should be thread safe.
239    Blocking,
240    /// A Win32 event object to be waited. The user should ensure that the
241    /// handle is valid till operation completes. The `operate` method should be
242    /// thread safe.
243    Event(RawFd),
244}
245
246/// Abstraction of IOCP operations.
247pub trait OpCode {
248    /// Determines that the operation is really overlapped defined by Windows
249    /// API. If not, the driver will try to operate it in another thread.
250    fn op_type(&self) -> OpType {
251        OpType::Overlapped
252    }
253
254    /// Perform Windows API call with given pointer to overlapped struct.
255    ///
256    /// It is always safe to cast `optr` to a pointer to
257    /// [`Overlapped<Self>`].
258    ///
259    /// Don't do heavy work here if [`OpCode::op_type`] returns
260    /// [`OpType::Event`].
261    ///
262    /// # Safety
263    ///
264    /// * `self` must be alive until the operation completes.
265    /// * When [`OpCode::op_type`] returns [`OpType::Blocking`], this method is
266    ///   called in another thread.
267    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
268
269    /// Cancel the async IO operation.
270    ///
271    /// Usually it calls `CancelIoEx`.
272    ///
273    /// # Safety
274    ///
275    /// * Should not use [`Overlapped::op`].
276    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
277        let _optr = optr; // ignore it
278        Ok(())
279    }
280}
281
282/// Low-level driver of IOCP.
283pub(crate) struct Driver {
284    notify: Arc<Notify>,
285    waits: HashMap<usize, wait::Wait>,
286    pool: AsyncifyPool,
287}
288
289impl Driver {
290    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
291        instrument!(compio_log::Level::TRACE, "new", ?builder);
292
293        let port = cp::Port::new()?;
294        let driver = port.as_raw_handle() as _;
295        let overlapped = Overlapped::new(driver);
296        let notify = Arc::new(Notify::new(port, overlapped));
297        Ok(Self {
298            notify,
299            waits: HashMap::default(),
300            pool: builder.create_or_get_thread_pool(),
301        })
302    }
303
304    pub fn driver_type(&self) -> DriverType {
305        DriverType::IOCP
306    }
307
308    fn port(&self) -> &cp::Port {
309        &self.notify.port
310    }
311
312    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
313        Key::new(self.port().as_raw_handle() as _, op)
314    }
315
316    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
317        self.port().attach(fd)
318    }
319
320    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
321        instrument!(compio_log::Level::TRACE, "cancel", ?op);
322        trace!("cancel RawOp");
323        let overlapped_ptr = op.as_mut_ptr();
324        if let Some(w) = self.waits.get_mut(&op.user_data())
325            && w.cancel().is_ok()
326        {
327            // The pack has been cancelled successfully, which means no packet will be post
328            // to IOCP. Need not set the result because `create_entry` handles it.
329            self.port().post_raw(overlapped_ptr).ok();
330        }
331        let op = op.as_op_pin();
332        // It's OK to fail to cancel.
333        trace!("call OpCode::cancel");
334        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
335    }
336
337    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
338        instrument!(compio_log::Level::TRACE, "push", ?op);
339        let user_data = op.user_data();
340        trace!("push RawOp");
341        let optr = op.as_mut_ptr();
342        let op_pin = op.as_op_pin();
343        match op_pin.op_type() {
344            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
345            OpType::Blocking => loop {
346                if self.push_blocking(user_data) {
347                    break Poll::Pending;
348                } else {
349                    // It's OK to wait forever, because any blocking task will notify the IOCP after
350                    // it completes.
351                    self.poll(None)?;
352                }
353            },
354            OpType::Event(e) => {
355                self.waits
356                    .insert(user_data, wait::Wait::new(self.notify.clone(), e, op)?);
357                Poll::Pending
358            }
359        }
360    }
361
362    fn push_blocking(&mut self, user_data: usize) -> bool {
363        let notify = self.notify.clone();
364        self.pool
365            .dispatch(move || {
366                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
367                let optr = op.as_mut_ptr();
368                let res = op.operate_blocking();
369                notify.port.post(res, optr).ok();
370            })
371            .is_ok()
372    }
373
374    fn create_entry(
375        notify_user_data: usize,
376        waits: &mut HashMap<usize, wait::Wait>,
377        entry: Entry,
378    ) -> Option<Entry> {
379        let user_data = entry.user_data();
380        if user_data != notify_user_data {
381            if let Some(w) = waits.remove(&user_data) {
382                if w.is_cancelled() {
383                    Some(Entry::new(
384                        user_data,
385                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
386                    ))
387                } else if entry.result.is_err() {
388                    Some(entry)
389                } else {
390                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
391                    let result = op.operate_blocking();
392                    Some(Entry::new(user_data, result))
393                }
394            } else {
395                Some(entry)
396            }
397        } else {
398            None
399        }
400    }
401
402    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
403        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
404
405        let notify_user_data = &self.notify.overlapped as *const Overlapped as usize;
406
407        for e in self.notify.port.poll(timeout)? {
408            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
409                // SAFETY: called only once.
410                unsafe { e.notify() }
411            }
412        }
413
414        Ok(())
415    }
416
417    pub fn waker(&self) -> Waker {
418        Waker::from(self.notify.clone())
419    }
420
421    pub fn create_buffer_pool(
422        &mut self,
423        buffer_len: u16,
424        buffer_size: usize,
425    ) -> io::Result<BufferPool> {
426        Ok(BufferPool::new(buffer_len, buffer_size))
427    }
428
429    /// # Safety
430    ///
431    /// caller must make sure release the buffer pool with correct driver
432    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
433        Ok(())
434    }
435}
436
437impl AsRawFd for Driver {
438    fn as_raw_fd(&self) -> RawFd {
439        self.port().as_raw_handle() as _
440    }
441}
442
443/// A notify handle to the inner driver.
444struct Notify {
445    port: cp::Port,
446    overlapped: Overlapped,
447}
448
449impl Notify {
450    fn new(port: cp::Port, overlapped: Overlapped) -> Self {
451        Self { port, overlapped }
452    }
453
454    /// Notify the inner driver.
455    pub fn notify(&self) -> io::Result<()> {
456        self.port.post_raw(&self.overlapped)
457    }
458}
459
460impl Wake for Notify {
461    fn wake(self: Arc<Self>) {
462        self.wake_by_ref();
463    }
464
465    fn wake_by_ref(self: &Arc<Self>) {
466        self.notify().ok();
467    }
468}
469
470/// The overlapped struct we actually used for IOCP.
471#[repr(C)]
472pub struct Overlapped {
473    /// The base [`OVERLAPPED`].
474    pub base: OVERLAPPED,
475    /// The unique ID of created driver.
476    pub driver: RawFd,
477}
478
479impl Overlapped {
480    pub(crate) fn new(driver: RawFd) -> Self {
481        Self {
482            base: unsafe { std::mem::zeroed() },
483            driver,
484        }
485    }
486}
487
488// SAFETY: neither field of `OVERLAPPED` is used
489unsafe impl Send for Overlapped {}
490unsafe impl Sync for Overlapped {}