compio_driver/iocp/
mod.rs

1use std::{
2    collections::HashMap,
3    io,
4    os::windows::io::{
5        AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6        OwnedSocket,
7    },
8    pin::Pin,
9    sync::Arc,
10    task::Poll,
11    time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{Foundation::ERROR_CANCELLED, System::IO::OVERLAPPED};
16
17use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder};
18
19pub(crate) mod op;
20
21mod cp;
22mod wait;
23
24pub(crate) use windows_sys::Win32::Networking::WinSock::{
25    SOCKADDR_STORAGE as sockaddr_storage, socklen_t,
26};
27
/// A raw Windows handle or socket, viewed as a pointer-sized integer.
///
/// Win32 handles and Windows sockets have the same size and either one can be
/// attached to an I/O completion port, so the driver treats both as an "fd".
pub type RawFd = isize;

/// Types that can expose their underlying [`RawFd`].
pub trait AsRawFd {
    /// Returns the raw fd without transferring ownership.
    fn as_raw_fd(&self) -> RawFd;
}
38
39/// Owned handle or socket on Windows.
40#[derive(Debug)]
41pub enum OwnedFd {
42    /// Win32 handle.
43    File(OwnedHandle),
44    /// Windows socket handle.
45    Socket(OwnedSocket),
46}
47
48impl AsRawFd for OwnedFd {
49    fn as_raw_fd(&self) -> RawFd {
50        match self {
51            Self::File(fd) => fd.as_raw_handle() as _,
52            Self::Socket(s) => s.as_raw_socket() as _,
53        }
54    }
55}
56
57impl AsRawFd for RawFd {
58    fn as_raw_fd(&self) -> RawFd {
59        *self
60    }
61}
62
63impl AsRawFd for std::fs::File {
64    fn as_raw_fd(&self) -> RawFd {
65        self.as_raw_handle() as _
66    }
67}
68
69impl AsRawFd for OwnedHandle {
70    fn as_raw_fd(&self) -> RawFd {
71        self.as_raw_handle() as _
72    }
73}
74
75impl AsRawFd for socket2::Socket {
76    fn as_raw_fd(&self) -> RawFd {
77        self.as_raw_socket() as _
78    }
79}
80
81impl AsRawFd for OwnedSocket {
82    fn as_raw_fd(&self) -> RawFd {
83        self.as_raw_socket() as _
84    }
85}
86
87impl AsRawFd for std::process::ChildStdin {
88    fn as_raw_fd(&self) -> RawFd {
89        self.as_raw_handle() as _
90    }
91}
92
93impl AsRawFd for std::process::ChildStdout {
94    fn as_raw_fd(&self) -> RawFd {
95        self.as_raw_handle() as _
96    }
97}
98
99impl AsRawFd for std::process::ChildStderr {
100    fn as_raw_fd(&self) -> RawFd {
101        self.as_raw_handle() as _
102    }
103}
104
105impl From<OwnedHandle> for OwnedFd {
106    fn from(value: OwnedHandle) -> Self {
107        Self::File(value)
108    }
109}
110
111impl From<std::fs::File> for OwnedFd {
112    fn from(value: std::fs::File) -> Self {
113        Self::File(OwnedHandle::from(value))
114    }
115}
116
117impl From<std::process::ChildStdin> for OwnedFd {
118    fn from(value: std::process::ChildStdin) -> Self {
119        Self::File(OwnedHandle::from(value))
120    }
121}
122
123impl From<std::process::ChildStdout> for OwnedFd {
124    fn from(value: std::process::ChildStdout) -> Self {
125        Self::File(OwnedHandle::from(value))
126    }
127}
128
129impl From<std::process::ChildStderr> for OwnedFd {
130    fn from(value: std::process::ChildStderr) -> Self {
131        Self::File(OwnedHandle::from(value))
132    }
133}
134
135impl From<OwnedSocket> for OwnedFd {
136    fn from(value: OwnedSocket) -> Self {
137        Self::Socket(value)
138    }
139}
140
141impl From<socket2::Socket> for OwnedFd {
142    fn from(value: socket2::Socket) -> Self {
143        Self::Socket(OwnedSocket::from(value))
144    }
145}
146
147/// Borrowed handle or socket on Windows.
148#[derive(Debug)]
149pub enum BorrowedFd<'a> {
150    /// Win32 handle.
151    File(BorrowedHandle<'a>),
152    /// Windows socket handle.
153    Socket(BorrowedSocket<'a>),
154}
155
156impl AsRawFd for BorrowedFd<'_> {
157    fn as_raw_fd(&self) -> RawFd {
158        match self {
159            Self::File(fd) => fd.as_raw_handle() as RawFd,
160            Self::Socket(s) => s.as_raw_socket() as RawFd,
161        }
162    }
163}
164
165impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
166    fn from(value: BorrowedHandle<'a>) -> Self {
167        Self::File(value)
168    }
169}
170
171impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
172    fn from(value: BorrowedSocket<'a>) -> Self {
173        Self::Socket(value)
174    }
175}
176
177/// Extracts fds.
178pub trait AsFd {
179    /// Extracts the borrowed fd.
180    fn as_fd(&self) -> BorrowedFd<'_>;
181}
182
183impl AsFd for OwnedFd {
184    fn as_fd(&self) -> BorrowedFd<'_> {
185        match self {
186            Self::File(fd) => fd.as_fd(),
187            Self::Socket(s) => s.as_fd(),
188        }
189    }
190}
191
192impl AsFd for std::fs::File {
193    fn as_fd(&self) -> BorrowedFd<'_> {
194        self.as_handle().into()
195    }
196}
197
198impl AsFd for OwnedHandle {
199    fn as_fd(&self) -> BorrowedFd<'_> {
200        self.as_handle().into()
201    }
202}
203
204impl AsFd for socket2::Socket {
205    fn as_fd(&self) -> BorrowedFd<'_> {
206        self.as_socket().into()
207    }
208}
209
210impl AsFd for OwnedSocket {
211    fn as_fd(&self) -> BorrowedFd<'_> {
212        self.as_socket().into()
213    }
214}
215
216impl AsFd for std::process::ChildStdin {
217    fn as_fd(&self) -> BorrowedFd<'_> {
218        self.as_handle().into()
219    }
220}
221
222impl AsFd for std::process::ChildStdout {
223    fn as_fd(&self) -> BorrowedFd<'_> {
224        self.as_handle().into()
225    }
226}
227
228impl AsFd for std::process::ChildStderr {
229    fn as_fd(&self) -> BorrowedFd<'_> {
230        self.as_handle().into()
231    }
232}
233
234/// Operation type.
235pub enum OpType {
236    /// An overlapped operation.
237    Overlapped,
238    /// A blocking operation, needs a thread to spawn. The `operate` method
239    /// should be thread safe.
240    Blocking,
241    /// A Win32 event object to be waited. The user should ensure that the
242    /// handle is valid till operation completes. The `operate` method should be
243    /// thread safe.
244    Event(RawFd),
245}
246
247/// Abstraction of IOCP operations.
248pub trait OpCode {
249    /// Determines that the operation is really overlapped defined by Windows
250    /// API. If not, the driver will try to operate it in another thread.
251    fn op_type(&self) -> OpType {
252        OpType::Overlapped
253    }
254
255    /// Perform Windows API call with given pointer to overlapped struct.
256    ///
257    /// It is always safe to cast `optr` to a pointer to
258    /// [`Overlapped<Self>`].
259    ///
260    /// Don't do heavy work here if [`OpCode::op_type`] returns
261    /// [`OpType::Event`].
262    ///
263    /// # Safety
264    ///
265    /// * `self` must be alive until the operation completes.
266    /// * When [`OpCode::op_type`] returns [`OpType::Blocking`], this method is
267    ///   called in another thread.
268    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
269
270    /// Cancel the async IO operation.
271    ///
272    /// Usually it calls `CancelIoEx`.
273    ///
274    /// # Safety
275    ///
276    /// * Should not use [`Overlapped::op`].
277    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
278        let _optr = optr; // ignore it
279        Ok(())
280    }
281}
282
283/// Low-level driver of IOCP.
284pub(crate) struct Driver {
285    port: cp::Port,
286    waits: HashMap<usize, wait::Wait>,
287    pool: AsyncifyPool,
288    notify_overlapped: Arc<Overlapped>,
289}
290
291impl Driver {
292    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
293        instrument!(compio_log::Level::TRACE, "new", ?builder);
294
295        let port = cp::Port::new()?;
296        let driver = port.as_raw_handle() as _;
297        Ok(Self {
298            port,
299            waits: HashMap::default(),
300            pool: builder.create_or_get_thread_pool(),
301            notify_overlapped: Arc::new(Overlapped::new(driver)),
302        })
303    }
304
305    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
306        Key::new(self.port.as_raw_handle() as _, op)
307    }
308
309    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
310        self.port.attach(fd)
311    }
312
313    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
314        instrument!(compio_log::Level::TRACE, "cancel", ?op);
315        trace!("cancel RawOp");
316        let overlapped_ptr = op.as_mut_ptr();
317        if let Some(w) = self.waits.get_mut(&op.user_data()) {
318            if w.cancel().is_ok() {
319                // The pack has been cancelled successfully, which means no packet will be post
320                // to IOCP. Need not set the result because `create_entry` handles it.
321                self.port.post_raw(overlapped_ptr).ok();
322            }
323        }
324        let op = op.as_op_pin();
325        // It's OK to fail to cancel.
326        trace!("call OpCode::cancel");
327        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
328    }
329
330    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
331        instrument!(compio_log::Level::TRACE, "push", ?op);
332        let user_data = op.user_data();
333        trace!("push RawOp");
334        let optr = op.as_mut_ptr();
335        let op_pin = op.as_op_pin();
336        match op_pin.op_type() {
337            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
338            OpType::Blocking => loop {
339                if self.push_blocking(user_data) {
340                    break Poll::Pending;
341                } else {
342                    // It's OK to wait forever, because any blocking task will notify the IOCP after
343                    // it completes.
344                    unsafe {
345                        self.poll(None)?;
346                    }
347                }
348            },
349            OpType::Event(e) => {
350                self.waits
351                    .insert(user_data, wait::Wait::new(&self.port, e, op)?);
352                Poll::Pending
353            }
354        }
355    }
356
357    fn push_blocking(&mut self, user_data: usize) -> bool {
358        let port = self.port.handle();
359        self.pool
360            .dispatch(move || {
361                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
362                let optr = op.as_mut_ptr();
363                let res = op.operate_blocking();
364                port.post(res, optr).ok();
365            })
366            .is_ok()
367    }
368
369    fn create_entry(
370        notify_user_data: usize,
371        waits: &mut HashMap<usize, wait::Wait>,
372        entry: Entry,
373    ) -> Option<Entry> {
374        let user_data = entry.user_data();
375        if user_data != notify_user_data {
376            if let Some(w) = waits.remove(&user_data) {
377                if w.is_cancelled() {
378                    Some(Entry::new(
379                        user_data,
380                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
381                    ))
382                } else if entry.result.is_err() {
383                    Some(entry)
384                } else {
385                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
386                    let result = op.operate_blocking();
387                    Some(Entry::new(user_data, result))
388                }
389            } else {
390                Some(entry)
391            }
392        } else {
393            None
394        }
395    }
396
397    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
398        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
399
400        let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
401
402        for e in self.port.poll(timeout)? {
403            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
404                e.notify();
405            }
406        }
407
408        Ok(())
409    }
410
411    pub fn handle(&self) -> NotifyHandle {
412        NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
413    }
414
415    pub fn create_buffer_pool(
416        &mut self,
417        buffer_len: u16,
418        buffer_size: usize,
419    ) -> io::Result<BufferPool> {
420        Ok(BufferPool::new(buffer_len, buffer_size))
421    }
422
423    /// # Safety
424    ///
425    /// caller must make sure release the buffer pool with correct driver
426    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
427        Ok(())
428    }
429}
430
431impl AsRawFd for Driver {
432    fn as_raw_fd(&self) -> RawFd {
433        self.port.as_raw_handle() as _
434    }
435}
436
437/// A notify handle to the inner driver.
438pub struct NotifyHandle {
439    port: cp::PortHandle,
440    overlapped: Arc<Overlapped>,
441}
442
443impl NotifyHandle {
444    fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
445        Self { port, overlapped }
446    }
447
448    /// Notify the inner driver.
449    pub fn notify(&self) -> io::Result<()> {
450        self.port.post_raw(self.overlapped.as_ref())
451    }
452}
453
454/// The overlapped struct we actually used for IOCP.
455#[repr(C)]
456pub struct Overlapped {
457    /// The base [`OVERLAPPED`].
458    pub base: OVERLAPPED,
459    /// The unique ID of created driver.
460    pub driver: RawFd,
461}
462
463impl Overlapped {
464    pub(crate) fn new(driver: RawFd) -> Self {
465        Self {
466            base: unsafe { std::mem::zeroed() },
467            driver,
468        }
469    }
470}
471
472// SAFETY: neither field of `OVERLAPPED` is used
473unsafe impl Send for Overlapped {}
474unsafe impl Sync for Overlapped {}