// compio_driver/sys/iocp/mod.rs

use std::{
    collections::HashMap,
    io,
    os::windows::io::{
        AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
        OwnedSocket,
    },
    pin::Pin,
    sync::Arc,
    task::{Poll, Wake, Waker},
    thread,
    time::Duration,
};

use compio_log::{instrument, trace};
use windows_sys::Win32::{
    Foundation::{ERROR_CANCELLED, HANDLE},
    System::IO::OVERLAPPED,
};

use crate::{
    AsyncifyPool, BufferPool, DriverType, Entry, ErasedKey, Key, ProactorBuilder, key::RefExt,
};

pub(crate) mod op;

mod cp;
mod wait;

29/// Extra data attached for IOCP.
30#[repr(C)]
31pub struct Extra {
32    overlapped: Overlapped,
33}
34
35impl Default for Extra {
36    fn default() -> Self {
37        Self {
38            overlapped: Overlapped::new(std::ptr::null_mut()),
39        }
40    }
41}
42
43impl Extra {
44    pub(crate) fn new(driver: RawFd) -> Self {
45        Self {
46            overlapped: Overlapped::new(driver),
47        }
48    }
49}
50
51impl super::Extra {
52    pub(crate) fn optr(&mut self) -> *mut Overlapped {
53        &mut self.0.overlapped as _
54    }
55}
56
57/// On windows, handle and socket are in the same size.
58/// Both of them could be attached to an IOCP.
59/// Therefore, both could be seen as fd.
60pub type RawFd = HANDLE;
61
62/// Extracts raw fds.
63pub trait AsRawFd {
64    /// Extracts the raw fd.
65    fn as_raw_fd(&self) -> RawFd;
66}
67
68/// Owned handle or socket on Windows.
69#[derive(Debug)]
70pub enum OwnedFd {
71    /// Win32 handle.
72    File(OwnedHandle),
73    /// Windows socket handle.
74    Socket(OwnedSocket),
75}
76
77impl AsRawFd for OwnedFd {
78    fn as_raw_fd(&self) -> RawFd {
79        match self {
80            Self::File(fd) => fd.as_raw_handle() as _,
81            Self::Socket(s) => s.as_raw_socket() as _,
82        }
83    }
84}
85
86impl AsRawFd for RawFd {
87    fn as_raw_fd(&self) -> RawFd {
88        *self
89    }
90}
91
92impl AsRawFd for std::fs::File {
93    fn as_raw_fd(&self) -> RawFd {
94        self.as_raw_handle() as _
95    }
96}
97
98impl AsRawFd for OwnedHandle {
99    fn as_raw_fd(&self) -> RawFd {
100        self.as_raw_handle() as _
101    }
102}
103
104impl AsRawFd for socket2::Socket {
105    fn as_raw_fd(&self) -> RawFd {
106        self.as_raw_socket() as _
107    }
108}
109
110impl AsRawFd for OwnedSocket {
111    fn as_raw_fd(&self) -> RawFd {
112        self.as_raw_socket() as _
113    }
114}
115
116impl AsRawFd for std::process::ChildStdin {
117    fn as_raw_fd(&self) -> RawFd {
118        self.as_raw_handle() as _
119    }
120}
121
122impl AsRawFd for std::process::ChildStdout {
123    fn as_raw_fd(&self) -> RawFd {
124        self.as_raw_handle() as _
125    }
126}
127
128impl AsRawFd for std::process::ChildStderr {
129    fn as_raw_fd(&self) -> RawFd {
130        self.as_raw_handle() as _
131    }
132}
133
134impl From<OwnedHandle> for OwnedFd {
135    fn from(value: OwnedHandle) -> Self {
136        Self::File(value)
137    }
138}
139
140impl From<std::fs::File> for OwnedFd {
141    fn from(value: std::fs::File) -> Self {
142        Self::File(OwnedHandle::from(value))
143    }
144}
145
146impl From<std::process::ChildStdin> for OwnedFd {
147    fn from(value: std::process::ChildStdin) -> Self {
148        Self::File(OwnedHandle::from(value))
149    }
150}
151
152impl From<std::process::ChildStdout> for OwnedFd {
153    fn from(value: std::process::ChildStdout) -> Self {
154        Self::File(OwnedHandle::from(value))
155    }
156}
157
158impl From<std::process::ChildStderr> for OwnedFd {
159    fn from(value: std::process::ChildStderr) -> Self {
160        Self::File(OwnedHandle::from(value))
161    }
162}
163
164impl From<OwnedSocket> for OwnedFd {
165    fn from(value: OwnedSocket) -> Self {
166        Self::Socket(value)
167    }
168}
169
170impl From<socket2::Socket> for OwnedFd {
171    fn from(value: socket2::Socket) -> Self {
172        Self::Socket(OwnedSocket::from(value))
173    }
174}
175
176/// Borrowed handle or socket on Windows.
177#[derive(Debug)]
178pub enum BorrowedFd<'a> {
179    /// Win32 handle.
180    File(BorrowedHandle<'a>),
181    /// Windows socket handle.
182    Socket(BorrowedSocket<'a>),
183}
184
185impl AsRawFd for BorrowedFd<'_> {
186    fn as_raw_fd(&self) -> RawFd {
187        match self {
188            Self::File(fd) => fd.as_raw_handle() as RawFd,
189            Self::Socket(s) => s.as_raw_socket() as RawFd,
190        }
191    }
192}
193
194impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
195    fn from(value: BorrowedHandle<'a>) -> Self {
196        Self::File(value)
197    }
198}
199
200impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
201    fn from(value: BorrowedSocket<'a>) -> Self {
202        Self::Socket(value)
203    }
204}
205
206/// Extracts fds.
207pub trait AsFd {
208    /// Extracts the borrowed fd.
209    fn as_fd(&self) -> BorrowedFd<'_>;
210}
211
212impl AsFd for OwnedFd {
213    fn as_fd(&self) -> BorrowedFd<'_> {
214        match self {
215            Self::File(fd) => fd.as_fd(),
216            Self::Socket(s) => s.as_fd(),
217        }
218    }
219}
220
221impl AsFd for std::fs::File {
222    fn as_fd(&self) -> BorrowedFd<'_> {
223        self.as_handle().into()
224    }
225}
226
227impl AsFd for OwnedHandle {
228    fn as_fd(&self) -> BorrowedFd<'_> {
229        self.as_handle().into()
230    }
231}
232
233impl AsFd for socket2::Socket {
234    fn as_fd(&self) -> BorrowedFd<'_> {
235        self.as_socket().into()
236    }
237}
238
239impl AsFd for OwnedSocket {
240    fn as_fd(&self) -> BorrowedFd<'_> {
241        self.as_socket().into()
242    }
243}
244
245impl AsFd for std::process::ChildStdin {
246    fn as_fd(&self) -> BorrowedFd<'_> {
247        self.as_handle().into()
248    }
249}
250
251impl AsFd for std::process::ChildStdout {
252    fn as_fd(&self) -> BorrowedFd<'_> {
253        self.as_handle().into()
254    }
255}
256
257impl AsFd for std::process::ChildStderr {
258    fn as_fd(&self) -> BorrowedFd<'_> {
259        self.as_handle().into()
260    }
261}
262
263/// Operation type.
264pub enum OpType {
265    /// An overlapped operation.
266    Overlapped,
267    /// A blocking operation, needs a thread to spawn. The `operate` method
268    /// should be thread safe.
269    Blocking,
270    /// A Win32 event object to be waited. The user should ensure that the
271    /// handle is valid till operation completes. The `operate` method should be
272    /// thread safe.
273    Event(RawFd),
274}
275
276/// Abstraction of IOCP operations.
277///
278/// # Safety
279///
280/// Implementors must ensure that the operation is safe to be polled
281/// according to the returned [`OpType`].
282pub unsafe trait OpCode {
283    /// Determines that the operation is really overlapped defined by Windows
284    /// API. If not, the driver will try to operate it in another thread.
285    fn op_type(&self) -> OpType {
286        OpType::Overlapped
287    }
288
289    /// Perform Windows API call with given pointer to overlapped struct.
290    ///
291    /// It is always safe to cast `optr` to a pointer to
292    /// [`Overlapped<Self>`].
293    ///
294    /// Don't do heavy work here if [`OpCode::op_type`] returns
295    /// [`OpType::Event`].
296    ///
297    /// # Safety
298    ///
299    /// * `self` must be alive until the operation completes.
300    /// * When [`OpCode::op_type`] returns [`OpType::Blocking`], this method is
301    ///   called in another thread.
302    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
303
304    /// Cancel the async IO operation.
305    ///
306    /// Usually it calls `CancelIoEx`.
307    // # Safety for implementors
308    //
309    // `optr` must not be dereferenced. It's only used as a marker to identify the
310    // operation.
311    fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
312        let _optr = optr; // ignore it
313        Ok(())
314    }
315}
316
317/// Low-level driver of IOCP.
318pub(crate) struct Driver {
319    notify: Arc<Notify>,
320    waits: HashMap<usize, wait::Wait>,
321    pool: AsyncifyPool,
322}
323
324impl Driver {
325    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
326        instrument!(compio_log::Level::TRACE, "new", ?builder);
327
328        let port = cp::Port::new()?;
329        let driver = port.as_raw_handle() as _;
330        let overlapped = Overlapped::new(driver);
331        let notify = Arc::new(Notify::new(port, overlapped));
332        Ok(Self {
333            notify,
334            waits: HashMap::default(),
335            pool: builder.create_or_get_thread_pool(),
336        })
337    }
338
339    pub fn driver_type(&self) -> DriverType {
340        DriverType::IOCP
341    }
342
343    fn port(&self) -> &cp::Port {
344        &self.notify.port
345    }
346
347    pub fn default_extra(&self) -> Extra {
348        Extra::new(self.port().as_raw_handle() as _)
349    }
350
351    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
352        self.port().attach(fd)
353    }
354
355    pub fn cancel<T>(&mut self, key: Key<T>) {
356        instrument!(compio_log::Level::TRACE, "cancel", ?key);
357        trace!("cancel RawOp");
358        let optr = key.borrow().extra_mut().optr();
359        if let Some(w) = self.waits.get_mut(&key.as_raw())
360            && w.cancel().is_ok()
361        {
362            // The pack has been cancelled successfully, which means no packet will be post
363            // to IOCP. Need not set the result because `create_entry` handles it.
364            self.port().post_raw(optr).ok();
365        }
366        trace!("call OpCode::cancel");
367        // It's OK to fail to cancel.
368        key.borrow().pinned_op().cancel(optr.cast()).ok();
369    }
370
371    pub fn push(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
372        instrument!(compio_log::Level::TRACE, "push", ?key);
373        trace!("push RawOp");
374        let mut op = key.borrow();
375        let optr = op.extra_mut().optr();
376        let pinned = op.pinned_op();
377        let op_type = pinned.op_type();
378        match op_type {
379            OpType::Overlapped => unsafe {
380                let res = pinned.operate(optr.cast());
381                drop(op);
382                if res.is_pending() {
383                    key.into_raw();
384                }
385                res
386            },
387            OpType::Blocking => {
388                drop(op);
389                self.push_blocking(key);
390                Poll::Pending
391            }
392            OpType::Event(e) => {
393                drop(op);
394                self.waits
395                    .insert(key.as_raw(), wait::Wait::new(self.notify.clone(), e, key)?);
396                Poll::Pending
397            }
398        }
399    }
400
401    fn push_blocking(&mut self, key: ErasedKey) {
402        let notify = self.notify.clone();
403        // SAFETY: we're submitting into the driver, so it's safe to freeze here.
404        let mut key = unsafe { key.freeze() };
405
406        let mut closure = move || {
407            let op = key.as_mut();
408            let res = op.operate_blocking();
409            let optr = op.extra_mut().optr();
410            // key will be unfronzen in `create_entry` when the result is ready
411            notify.port.post(res, optr).ok();
412        };
413
414        while let Err(e) = self.pool.dispatch(closure) {
415            closure = e.0;
416        }
417    }
418
419    fn create_entry(
420        notify: *const Overlapped,
421        waits: &mut HashMap<usize, wait::Wait>,
422        entry: cp::RawEntry,
423    ) -> Option<Entry> {
424        // Ignore existing entries
425        if entry.overlapped.cast_const() == notify {
426            return None;
427        }
428
429        let entry = Entry::new(
430            unsafe { ErasedKey::from_optr(entry.overlapped) },
431            entry.result,
432        );
433
434        // if there's no wait, just return the entry
435        let Some(w) = waits.remove(&entry.user_data()) else {
436            return Some(entry);
437        };
438
439        let entry = if w.is_cancelled() {
440            Entry::new(
441                entry.into_key(),
442                Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
443            )
444        } else if entry.result.is_err() {
445            entry
446        } else {
447            let key = entry.into_key();
448            let result = key.borrow().operate_blocking();
449            Entry::new(key, result)
450        };
451
452        Some(entry)
453    }
454
455    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
456        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
457
458        let notify = &self.notify.overlapped as *const Overlapped;
459
460        for e in self.notify.port.poll(timeout)? {
461            if let Some(e) = Self::create_entry(notify, &mut self.waits, e) {
462                e.notify()
463            }
464        }
465
466        Ok(())
467    }
468
469    pub fn waker(&self) -> Waker {
470        Waker::from(self.notify.clone())
471    }
472
473    pub fn create_buffer_pool(
474        &mut self,
475        buffer_len: u16,
476        buffer_size: usize,
477    ) -> io::Result<BufferPool> {
478        Ok(BufferPool::new(buffer_len, buffer_size))
479    }
480
481    /// # Safety
482    ///
483    /// caller must make sure release the buffer pool with correct driver
484    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
485        Ok(())
486    }
487}
488
489impl AsRawFd for Driver {
490    fn as_raw_fd(&self) -> RawFd {
491        self.port().as_raw_handle() as _
492    }
493}
494
495/// A notify handle to the inner driver.
496pub(crate) struct Notify {
497    port: cp::Port,
498    overlapped: Overlapped,
499}
500
501impl Notify {
502    fn new(port: cp::Port, overlapped: Overlapped) -> Self {
503        Self { port, overlapped }
504    }
505
506    /// Notify the inner driver.
507    pub fn notify(&self) -> io::Result<()> {
508        self.port.post_raw(&self.overlapped)
509    }
510}
511
512impl Wake for Notify {
513    fn wake(self: Arc<Self>) {
514        self.wake_by_ref();
515    }
516
517    fn wake_by_ref(self: &Arc<Self>) {
518        self.notify().ok();
519    }
520}
521
522/// The overlapped struct we actually used for IOCP.
523#[repr(C)]
524pub struct Overlapped {
525    /// The base [`OVERLAPPED`].
526    pub base: OVERLAPPED,
527    /// The unique ID of created driver.
528    pub driver: RawFd,
529}
530
531impl Overlapped {
532    pub(crate) fn new(driver: RawFd) -> Self {
533        Self {
534            base: unsafe { std::mem::zeroed() },
535            driver,
536        }
537    }
538}
539
540// SAFETY: neither field of `OVERLAPPED` is used
541unsafe impl Send for Overlapped {}
542unsafe impl Sync for Overlapped {}