compio_driver/iocp/mod.rs

1use std::{
2    collections::HashMap,
3    io,
4    os::windows::io::{
5        AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6        OwnedSocket,
7    },
8    pin::Pin,
9    sync::Arc,
10    task::Poll,
11    time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{
16    Foundation::{ERROR_CANCELLED, HANDLE},
17    System::IO::OVERLAPPED,
18};
19
20use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder};
21
22pub(crate) mod op;
23
24mod cp;
25mod wait;
26
/// Raw OS handle type used by this driver.
///
/// On Windows, handles and sockets have the same size, and both can be
/// attached to an IOCP. Therefore both are treated as an "fd" here.
pub type RawFd = HANDLE;
31
/// Extracts raw fds ([`RawFd`]) from handle- or socket-like values.
pub trait AsRawFd {
    /// Extracts the raw fd.
    ///
    /// Only `&self` is taken, so ownership of the handle or socket is not
    /// transferred to the caller.
    fn as_raw_fd(&self) -> RawFd;
}
37
/// Owned handle or socket on Windows.
///
/// Dropping an `OwnedFd` drops the wrapped [`OwnedHandle`] or [`OwnedSocket`].
#[derive(Debug)]
pub enum OwnedFd {
    /// Win32 handle (e.g. files and child-process stdio).
    File(OwnedHandle),
    /// Windows socket handle.
    Socket(OwnedSocket),
}
46
47impl AsRawFd for OwnedFd {
48    fn as_raw_fd(&self) -> RawFd {
49        match self {
50            Self::File(fd) => fd.as_raw_handle() as _,
51            Self::Socket(s) => s.as_raw_socket() as _,
52        }
53    }
54}
55
56impl AsRawFd for RawFd {
57    fn as_raw_fd(&self) -> RawFd {
58        *self
59    }
60}
61
62impl AsRawFd for std::fs::File {
63    fn as_raw_fd(&self) -> RawFd {
64        self.as_raw_handle() as _
65    }
66}
67
68impl AsRawFd for OwnedHandle {
69    fn as_raw_fd(&self) -> RawFd {
70        self.as_raw_handle() as _
71    }
72}
73
74impl AsRawFd for socket2::Socket {
75    fn as_raw_fd(&self) -> RawFd {
76        self.as_raw_socket() as _
77    }
78}
79
80impl AsRawFd for OwnedSocket {
81    fn as_raw_fd(&self) -> RawFd {
82        self.as_raw_socket() as _
83    }
84}
85
86impl AsRawFd for std::process::ChildStdin {
87    fn as_raw_fd(&self) -> RawFd {
88        self.as_raw_handle() as _
89    }
90}
91
92impl AsRawFd for std::process::ChildStdout {
93    fn as_raw_fd(&self) -> RawFd {
94        self.as_raw_handle() as _
95    }
96}
97
98impl AsRawFd for std::process::ChildStderr {
99    fn as_raw_fd(&self) -> RawFd {
100        self.as_raw_handle() as _
101    }
102}
103
104impl From<OwnedHandle> for OwnedFd {
105    fn from(value: OwnedHandle) -> Self {
106        Self::File(value)
107    }
108}
109
110impl From<std::fs::File> for OwnedFd {
111    fn from(value: std::fs::File) -> Self {
112        Self::File(OwnedHandle::from(value))
113    }
114}
115
116impl From<std::process::ChildStdin> for OwnedFd {
117    fn from(value: std::process::ChildStdin) -> Self {
118        Self::File(OwnedHandle::from(value))
119    }
120}
121
122impl From<std::process::ChildStdout> for OwnedFd {
123    fn from(value: std::process::ChildStdout) -> Self {
124        Self::File(OwnedHandle::from(value))
125    }
126}
127
128impl From<std::process::ChildStderr> for OwnedFd {
129    fn from(value: std::process::ChildStderr) -> Self {
130        Self::File(OwnedHandle::from(value))
131    }
132}
133
134impl From<OwnedSocket> for OwnedFd {
135    fn from(value: OwnedSocket) -> Self {
136        Self::Socket(value)
137    }
138}
139
140impl From<socket2::Socket> for OwnedFd {
141    fn from(value: socket2::Socket) -> Self {
142        Self::Socket(OwnedSocket::from(value))
143    }
144}
145
/// Borrowed handle or socket on Windows.
///
/// The lifetime `'a` ties this value to the owner of the handle or socket.
#[derive(Debug)]
pub enum BorrowedFd<'a> {
    /// Win32 handle.
    File(BorrowedHandle<'a>),
    /// Windows socket handle.
    Socket(BorrowedSocket<'a>),
}
154
155impl AsRawFd for BorrowedFd<'_> {
156    fn as_raw_fd(&self) -> RawFd {
157        match self {
158            Self::File(fd) => fd.as_raw_handle() as RawFd,
159            Self::Socket(s) => s.as_raw_socket() as RawFd,
160        }
161    }
162}
163
164impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
165    fn from(value: BorrowedHandle<'a>) -> Self {
166        Self::File(value)
167    }
168}
169
170impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
171    fn from(value: BorrowedSocket<'a>) -> Self {
172        Self::Socket(value)
173    }
174}
175
/// Extracts fds as [`BorrowedFd`] values.
pub trait AsFd {
    /// Extracts the borrowed fd.
    ///
    /// The returned [`BorrowedFd`] borrows from `self` and cannot outlive it.
    fn as_fd(&self) -> BorrowedFd<'_>;
}
181
182impl AsFd for OwnedFd {
183    fn as_fd(&self) -> BorrowedFd<'_> {
184        match self {
185            Self::File(fd) => fd.as_fd(),
186            Self::Socket(s) => s.as_fd(),
187        }
188    }
189}
190
191impl AsFd for std::fs::File {
192    fn as_fd(&self) -> BorrowedFd<'_> {
193        self.as_handle().into()
194    }
195}
196
197impl AsFd for OwnedHandle {
198    fn as_fd(&self) -> BorrowedFd<'_> {
199        self.as_handle().into()
200    }
201}
202
203impl AsFd for socket2::Socket {
204    fn as_fd(&self) -> BorrowedFd<'_> {
205        self.as_socket().into()
206    }
207}
208
209impl AsFd for OwnedSocket {
210    fn as_fd(&self) -> BorrowedFd<'_> {
211        self.as_socket().into()
212    }
213}
214
215impl AsFd for std::process::ChildStdin {
216    fn as_fd(&self) -> BorrowedFd<'_> {
217        self.as_handle().into()
218    }
219}
220
221impl AsFd for std::process::ChildStdout {
222    fn as_fd(&self) -> BorrowedFd<'_> {
223        self.as_handle().into()
224    }
225}
226
227impl AsFd for std::process::ChildStderr {
228    fn as_fd(&self) -> BorrowedFd<'_> {
229        self.as_handle().into()
230    }
231}
232
/// Operation type.
///
/// Returned by [`OpCode::op_type`]; the driver's `push` uses it to decide how
/// an operation is executed.
pub enum OpType {
    /// An overlapped operation. The driver's `push` calls
    /// [`OpCode::operate`] directly.
    Overlapped,
    /// A blocking operation, needs a thread to spawn. The `operate` method
    /// should be thread safe.
    ///
    /// The driver dispatches it to its thread pool, and the worker posts the
    /// result to the completion port when done.
    Blocking,
    /// A Win32 event object to be waited on. The user should ensure that the
    /// handle is valid until the operation completes. The `operate` method
    /// should be thread safe.
    ///
    /// When the wait completes, the driver runs `operate_blocking` on the
    /// operation's key to produce the result. It does not do this if the wait
    /// was cancelled or reported an error.
    Event(RawFd),
}
245
246/// Abstraction of IOCP operations.
247pub trait OpCode {
248    /// Determines that the operation is really overlapped defined by Windows
249    /// API. If not, the driver will try to operate it in another thread.
250    fn op_type(&self) -> OpType {
251        OpType::Overlapped
252    }
253
254    /// Perform Windows API call with given pointer to overlapped struct.
255    ///
256    /// It is always safe to cast `optr` to a pointer to
257    /// [`Overlapped<Self>`].
258    ///
259    /// Don't do heavy work here if [`OpCode::op_type`] returns
260    /// [`OpType::Event`].
261    ///
262    /// # Safety
263    ///
264    /// * `self` must be alive until the operation completes.
265    /// * When [`OpCode::op_type`] returns [`OpType::Blocking`], this method is
266    ///   called in another thread.
267    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
268
269    /// Cancel the async IO operation.
270    ///
271    /// Usually it calls `CancelIoEx`.
272    ///
273    /// # Safety
274    ///
275    /// * Should not use [`Overlapped::op`].
276    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
277        let _optr = optr; // ignore it
278        Ok(())
279    }
280}
281
/// Low-level driver of IOCP.
pub(crate) struct Driver {
    /// The completion port that operations are attached to and completed on.
    port: cp::Port,
    /// Pending [`OpType::Event`] waits, keyed by the operation's user data.
    waits: HashMap<usize, wait::Wait>,
    /// Thread pool that runs [`OpType::Blocking`] operations.
    pool: AsyncifyPool,
    /// Overlapped used only to wake the driver. Its address identifies
    /// notification packets so that `poll` can skip them.
    notify_overlapped: Arc<Overlapped>,
}
289
290impl Driver {
291    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
292        instrument!(compio_log::Level::TRACE, "new", ?builder);
293
294        let port = cp::Port::new()?;
295        let driver = port.as_raw_handle() as _;
296        Ok(Self {
297            port,
298            waits: HashMap::default(),
299            pool: builder.create_or_get_thread_pool(),
300            notify_overlapped: Arc::new(Overlapped::new(driver)),
301        })
302    }
303
304    pub fn driver_type(&self) -> DriverType {
305        DriverType::IOCP
306    }
307
308    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
309        Key::new(self.port.as_raw_handle() as _, op)
310    }
311
312    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
313        self.port.attach(fd)
314    }
315
316    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
317        instrument!(compio_log::Level::TRACE, "cancel", ?op);
318        trace!("cancel RawOp");
319        let overlapped_ptr = op.as_mut_ptr();
320        if let Some(w) = self.waits.get_mut(&op.user_data()) {
321            if w.cancel().is_ok() {
322                // The pack has been cancelled successfully, which means no packet will be post
323                // to IOCP. Need not set the result because `create_entry` handles it.
324                self.port.post_raw(overlapped_ptr).ok();
325            }
326        }
327        let op = op.as_op_pin();
328        // It's OK to fail to cancel.
329        trace!("call OpCode::cancel");
330        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
331    }
332
333    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
334        instrument!(compio_log::Level::TRACE, "push", ?op);
335        let user_data = op.user_data();
336        trace!("push RawOp");
337        let optr = op.as_mut_ptr();
338        let op_pin = op.as_op_pin();
339        match op_pin.op_type() {
340            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
341            OpType::Blocking => loop {
342                if self.push_blocking(user_data) {
343                    break Poll::Pending;
344                } else {
345                    // It's OK to wait forever, because any blocking task will notify the IOCP after
346                    // it completes.
347                    unsafe {
348                        self.poll(None)?;
349                    }
350                }
351            },
352            OpType::Event(e) => {
353                self.waits
354                    .insert(user_data, wait::Wait::new(&self.port, e, op)?);
355                Poll::Pending
356            }
357        }
358    }
359
360    fn push_blocking(&mut self, user_data: usize) -> bool {
361        let port = self.port.handle();
362        self.pool
363            .dispatch(move || {
364                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
365                let optr = op.as_mut_ptr();
366                let res = op.operate_blocking();
367                port.post(res, optr).ok();
368            })
369            .is_ok()
370    }
371
372    fn create_entry(
373        notify_user_data: usize,
374        waits: &mut HashMap<usize, wait::Wait>,
375        entry: Entry,
376    ) -> Option<Entry> {
377        let user_data = entry.user_data();
378        if user_data != notify_user_data {
379            if let Some(w) = waits.remove(&user_data) {
380                if w.is_cancelled() {
381                    Some(Entry::new(
382                        user_data,
383                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
384                    ))
385                } else if entry.result.is_err() {
386                    Some(entry)
387                } else {
388                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
389                    let result = op.operate_blocking();
390                    Some(Entry::new(user_data, result))
391                }
392            } else {
393                Some(entry)
394            }
395        } else {
396            None
397        }
398    }
399
400    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
401        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
402
403        let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
404
405        for e in self.port.poll(timeout)? {
406            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
407                e.notify();
408            }
409        }
410
411        Ok(())
412    }
413
414    pub fn handle(&self) -> NotifyHandle {
415        NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
416    }
417
418    pub fn create_buffer_pool(
419        &mut self,
420        buffer_len: u16,
421        buffer_size: usize,
422    ) -> io::Result<BufferPool> {
423        Ok(BufferPool::new(buffer_len, buffer_size))
424    }
425
426    /// # Safety
427    ///
428    /// caller must make sure release the buffer pool with correct driver
429    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
430        Ok(())
431    }
432}
433
434impl AsRawFd for Driver {
435    fn as_raw_fd(&self) -> RawFd {
436        self.port.as_raw_handle() as _
437    }
438}
439
/// A notify handle to the inner driver.
///
/// Obtained from the driver's `handle` method. [`NotifyHandle::notify`] posts
/// the driver's notification overlapped to its completion port.
pub struct NotifyHandle {
    /// Handle of the driver's completion port.
    port: cp::PortHandle,
    /// The driver's notification overlapped, shared with the driver.
    overlapped: Arc<Overlapped>,
}
445
446impl NotifyHandle {
447    fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
448        Self { port, overlapped }
449    }
450
451    /// Notify the inner driver.
452    pub fn notify(&self) -> io::Result<()> {
453        self.port.post_raw(self.overlapped.as_ref())
454    }
455}
456
/// The overlapped struct we actually use for IOCP.
///
/// `#[repr(C)]` keeps the [`OVERLAPPED`] as the first field, so a pointer to
/// this struct can also be used as a pointer to its `base`.
#[repr(C)]
pub struct Overlapped {
    /// The base [`OVERLAPPED`].
    pub base: OVERLAPPED,
    /// The unique ID of the driver that created this overlapped. In this
    /// module it is set from the driver's completion-port handle.
    pub driver: RawFd,
}
465
466impl Overlapped {
467    pub(crate) fn new(driver: RawFd) -> Self {
468        Self {
469            base: unsafe { std::mem::zeroed() },
470            driver,
471        }
472    }
473}
474
// SAFETY: neither field of `OVERLAPPED` is used
// NOTE(review): `Overlapped` holds raw handle/pointer-typed data (`OVERLAPPED`,
// `RawFd`), so it is presumably not auto-`Send`/`Sync`. These impls allow it to
// be shared through `Arc` (e.g. by `NotifyHandle`). Soundness relies on no code
// reading or writing those fields concurrently; confirm this invariant still
// holds for every user of `Overlapped`.
unsafe impl Send for Overlapped {}
unsafe impl Sync for Overlapped {}