1use std::{
2 collections::HashMap,
3 io,
4 os::windows::io::{
5 AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6 OwnedSocket,
7 },
8 pin::Pin,
9 sync::Arc,
10 task::{Poll, Wake, Waker},
11 time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{
16 Foundation::{ERROR_CANCELLED, HANDLE},
17 System::IO::OVERLAPPED,
18};
19
20use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder};
21
22pub(crate) mod op;
23
24mod cp;
25mod wait;
26
/// Raw OS handle type used throughout this driver: the Win32 `HANDLE`.
///
/// Sockets are also represented by this type; their raw `SOCKET` value is cast
/// to `RawFd` by the `AsRawFd` implementations in this module.
pub type RawFd = HANDLE;
31
/// Extracts the raw handle of a file, pipe or socket as a [`RawFd`].
pub trait AsRawFd {
    /// Returns the raw handle without transferring ownership.
    ///
    /// For sockets, the implementations in this module cast the raw `SOCKET`
    /// value to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd;
}
37
/// An owned Windows handle: either a file-like `HANDLE` or a socket.
#[derive(Debug)]
pub enum OwnedFd {
    /// An owned `HANDLE`. Files and child-process pipes are stored here.
    File(OwnedHandle),
    /// An owned socket.
    Socket(OwnedSocket),
}
46
47impl AsRawFd for OwnedFd {
48 fn as_raw_fd(&self) -> RawFd {
49 match self {
50 Self::File(fd) => fd.as_raw_handle() as _,
51 Self::Socket(s) => s.as_raw_socket() as _,
52 }
53 }
54}
55
56impl AsRawFd for RawFd {
57 fn as_raw_fd(&self) -> RawFd {
58 *self
59 }
60}
61
62impl AsRawFd for std::fs::File {
63 fn as_raw_fd(&self) -> RawFd {
64 self.as_raw_handle() as _
65 }
66}
67
68impl AsRawFd for OwnedHandle {
69 fn as_raw_fd(&self) -> RawFd {
70 self.as_raw_handle() as _
71 }
72}
73
74impl AsRawFd for socket2::Socket {
75 fn as_raw_fd(&self) -> RawFd {
76 self.as_raw_socket() as _
77 }
78}
79
80impl AsRawFd for OwnedSocket {
81 fn as_raw_fd(&self) -> RawFd {
82 self.as_raw_socket() as _
83 }
84}
85
86impl AsRawFd for std::process::ChildStdin {
87 fn as_raw_fd(&self) -> RawFd {
88 self.as_raw_handle() as _
89 }
90}
91
92impl AsRawFd for std::process::ChildStdout {
93 fn as_raw_fd(&self) -> RawFd {
94 self.as_raw_handle() as _
95 }
96}
97
98impl AsRawFd for std::process::ChildStderr {
99 fn as_raw_fd(&self) -> RawFd {
100 self.as_raw_handle() as _
101 }
102}
103
104impl From<OwnedHandle> for OwnedFd {
105 fn from(value: OwnedHandle) -> Self {
106 Self::File(value)
107 }
108}
109
110impl From<std::fs::File> for OwnedFd {
111 fn from(value: std::fs::File) -> Self {
112 Self::File(OwnedHandle::from(value))
113 }
114}
115
116impl From<std::process::ChildStdin> for OwnedFd {
117 fn from(value: std::process::ChildStdin) -> Self {
118 Self::File(OwnedHandle::from(value))
119 }
120}
121
122impl From<std::process::ChildStdout> for OwnedFd {
123 fn from(value: std::process::ChildStdout) -> Self {
124 Self::File(OwnedHandle::from(value))
125 }
126}
127
128impl From<std::process::ChildStderr> for OwnedFd {
129 fn from(value: std::process::ChildStderr) -> Self {
130 Self::File(OwnedHandle::from(value))
131 }
132}
133
134impl From<OwnedSocket> for OwnedFd {
135 fn from(value: OwnedSocket) -> Self {
136 Self::Socket(value)
137 }
138}
139
140impl From<socket2::Socket> for OwnedFd {
141 fn from(value: socket2::Socket) -> Self {
142 Self::Socket(OwnedSocket::from(value))
143 }
144}
145
/// A borrowed Windows handle or socket, valid for the lifetime `'a` of its owner.
#[derive(Debug)]
pub enum BorrowedFd<'a> {
    /// A borrowed `HANDLE`. Files and child-process pipes are exposed this way.
    File(BorrowedHandle<'a>),
    /// A borrowed socket.
    Socket(BorrowedSocket<'a>),
}
154
155impl AsRawFd for BorrowedFd<'_> {
156 fn as_raw_fd(&self) -> RawFd {
157 match self {
158 Self::File(fd) => fd.as_raw_handle() as RawFd,
159 Self::Socket(s) => s.as_raw_socket() as RawFd,
160 }
161 }
162}
163
164impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
165 fn from(value: BorrowedHandle<'a>) -> Self {
166 Self::File(value)
167 }
168}
169
170impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
171 fn from(value: BorrowedSocket<'a>) -> Self {
172 Self::Socket(value)
173 }
174}
175
/// Borrows a file handle, pipe or socket as a [`BorrowedFd`].
pub trait AsFd {
    /// Returns a borrowed view whose lifetime is tied to `self`.
    fn as_fd(&self) -> BorrowedFd<'_>;
}
181
182impl AsFd for OwnedFd {
183 fn as_fd(&self) -> BorrowedFd<'_> {
184 match self {
185 Self::File(fd) => fd.as_fd(),
186 Self::Socket(s) => s.as_fd(),
187 }
188 }
189}
190
191impl AsFd for std::fs::File {
192 fn as_fd(&self) -> BorrowedFd<'_> {
193 self.as_handle().into()
194 }
195}
196
197impl AsFd for OwnedHandle {
198 fn as_fd(&self) -> BorrowedFd<'_> {
199 self.as_handle().into()
200 }
201}
202
203impl AsFd for socket2::Socket {
204 fn as_fd(&self) -> BorrowedFd<'_> {
205 self.as_socket().into()
206 }
207}
208
209impl AsFd for OwnedSocket {
210 fn as_fd(&self) -> BorrowedFd<'_> {
211 self.as_socket().into()
212 }
213}
214
215impl AsFd for std::process::ChildStdin {
216 fn as_fd(&self) -> BorrowedFd<'_> {
217 self.as_handle().into()
218 }
219}
220
221impl AsFd for std::process::ChildStdout {
222 fn as_fd(&self) -> BorrowedFd<'_> {
223 self.as_handle().into()
224 }
225}
226
227impl AsFd for std::process::ChildStderr {
228 fn as_fd(&self) -> BorrowedFd<'_> {
229 self.as_handle().into()
230 }
231}
232
/// How the driver submits an operation; chosen by [`OpCode::op_type`].
pub enum OpType {
    /// Submitted directly: the driver calls [`OpCode::operate`] with the
    /// operation's overlapped pointer and returns its result. This is the default.
    Overlapped,
    /// Run on the driver's blocking thread pool via the key's `operate_blocking`.
    /// The result is posted back to the completion port when the task finishes.
    Blocking,
    /// The handle is passed to the driver's `wait` machinery. Presumably the wait
    /// completes when the handle is signalled (TODO confirm in the `wait` module).
    /// On a successful completion the driver runs the operation synchronously via
    /// `operate_blocking`. A cancelled wait completes with `ERROR_CANCELLED`.
    Event(RawFd),
}
245
246pub trait OpCode {
248 fn op_type(&self) -> OpType {
251 OpType::Overlapped
252 }
253
254 unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
268
269 unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
277 let _optr = optr; Ok(())
279 }
280}
281
/// The IOCP-based proactor driver.
pub(crate) struct Driver {
    /// The completion port plus the `Overlapped` used for wake-up packets.
    /// It is shared through `Arc` with wakers and with blocking-pool tasks.
    notify: Arc<Notify>,
    /// Registered `OpType::Event` waits, keyed by the operation's user data.
    waits: HashMap<usize, wait::Wait>,
    /// Thread pool that runs `OpType::Blocking` operations.
    pool: AsyncifyPool,
}
288
impl Driver {
    /// Creates a driver backed by a new I/O completion port.
    ///
    /// The blocking thread pool is taken from `builder` via
    /// `create_or_get_thread_pool`.
    ///
    /// # Errors
    ///
    /// Returns any error from creating the completion port.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);

        let port = cp::Port::new()?;
        // The wake-up `Overlapped` is tagged with the port's own raw handle.
        let driver = port.as_raw_handle() as _;
        let overlapped = Overlapped::new(driver);
        let notify = Arc::new(Notify::new(port, overlapped));
        Ok(Self {
            notify,
            waits: HashMap::default(),
            pool: builder.create_or_get_thread_pool(),
        })
    }

    /// Identifies this driver as the IOCP backend.
    pub fn driver_type(&self) -> DriverType {
        DriverType::IOCP
    }

    /// The completion port, which lives inside the shared `Notify`.
    fn port(&self) -> &cp::Port {
        &self.notify.port
    }

    /// Wraps `op` in a [`Key`] tagged with this driver's completion-port handle.
    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.port().as_raw_handle() as _, op)
    }

    /// Attaches `fd` to the completion port. This delegates to `cp::Port::attach`.
    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
        self.port().attach(fd)
    }

    /// Requests cancellation of `op`.
    ///
    /// If `op` has a registered event wait and that wait cancels successfully,
    /// the op's overlapped pointer is posted to the port. This lets `poll` still
    /// observe a completion, which `create_entry` turns into `ERROR_CANCELLED`.
    /// Presumably a cancelled wait posts nothing itself.
    ///
    /// [`OpCode::cancel`] is then called unconditionally, and its error is ignored.
    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        let overlapped_ptr = op.as_mut_ptr();
        if let Some(w) = self.waits.get_mut(&op.user_data())
            && w.cancel().is_ok()
        {
            // NOTE(review): the post error is discarded. If posting fails here,
            // it looks like no completion would ever be delivered for this op —
            // confirm whether this should at least be logged.
            self.port().post_raw(overlapped_ptr).ok();
        }
        let op = op.as_op_pin();
        trace!("call OpCode::cancel");
        // SAFETY: NOTE(review) — `overlapped_ptr` comes from this same key's
        // `as_mut_ptr`, which presumably satisfies `OpCode::cancel`'s contract.
        // A failed cancel is deliberately ignored.
        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
    }

    /// Submits `op` according to its [`OpType`].
    ///
    /// * `Overlapped`: calls [`OpCode::operate`] and returns its result directly.
    /// * `Blocking`: hands the op to the thread pool and returns `Poll::Pending`.
    ///   While the pool rejects the task, the driver polls with no timeout and
    ///   retries.
    /// * `Event(e)`: registers a `wait::Wait` for `e` under the op's user data and
    ///   returns `Poll::Pending`.
    ///
    /// Errors from polling or from creating the wait are returned as
    /// `Poll::Ready(Err(_))` through `?`.
    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        trace!("push RawOp");
        let optr = op.as_mut_ptr();
        let op_pin = op.as_op_pin();
        match op_pin.op_type() {
            // SAFETY: NOTE(review) — `optr` is this key's own overlapped pointer,
            // presumably what `OpCode::operate` requires.
            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
            OpType::Blocking => loop {
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    // The pool rejected the task, so process completions and
                    // retry. Presumably a running blocking task will post its
                    // result and free a slot, which is why waiting without a
                    // timeout is acceptable.
                    self.poll(None)?;
                }
            },
            OpType::Event(e) => {
                self.waits
                    .insert(user_data, wait::Wait::new(self.notify.clone(), e, op)?);
                Poll::Pending
            }
        }
    }

    /// Dispatches the op identified by `user_data` to the blocking thread pool.
    ///
    /// The task runs the op with `operate_blocking` and posts the result to the
    /// completion port together with the op's overlapped pointer.
    /// Returns `false` if the pool did not accept the task.
    fn push_blocking(&mut self, user_data: usize) -> bool {
        let notify = self.notify.clone();
        self.pool
            .dispatch(move || {
                // SAFETY: NOTE(review) — assumes `user_data` identifies a key that
                // stays alive until its completion is processed; confirm with
                // `Key::new_unchecked`'s contract.
                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
                let optr = op.as_mut_ptr();
                let res = op.operate_blocking();
                // NOTE(review): a failed post is ignored, which would leave the op
                // without a completion — confirm this cannot happen in practice.
                notify.port.post(res, optr).ok();
            })
            .is_ok()
    }

    /// Converts a raw completion `entry` into the entry to deliver, if any.
    ///
    /// * Wake-up packets (user data equal to `notify_user_data`) yield `None`.
    /// * If the user data has a registered event wait, the wait is removed. Then:
    ///   a cancelled wait yields an `ERROR_CANCELLED` error; an error result is
    ///   passed through unchanged; otherwise the op is run synchronously with
    ///   `operate_blocking` and that result is delivered.
    /// * Any other entry is delivered unchanged.
    ///
    /// This takes `waits` explicitly instead of `&mut self`. Presumably that lets
    /// `poll` borrow `self.waits` mutably while iterating the port inside
    /// `self.notify`.
    fn create_entry(
        notify_user_data: usize,
        waits: &mut HashMap<usize, wait::Wait>,
        entry: Entry,
    ) -> Option<Entry> {
        let user_data = entry.user_data();
        if user_data != notify_user_data {
            if let Some(w) = waits.remove(&user_data) {
                if w.is_cancelled() {
                    Some(Entry::new(
                        user_data,
                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
                    ))
                } else if entry.result.is_err() {
                    Some(entry)
                } else {
                    // SAFETY: NOTE(review) — `user_data` was registered in
                    // `push`; assumes the key is still alive here. Confirm.
                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
                    let result = op.operate_blocking();
                    Some(Entry::new(user_data, result))
                }
            } else {
                Some(entry)
            }
        } else {
            None
        }
    }

    /// Polls the completion port and notifies the operation behind every entry.
    ///
    /// `timeout` is forwarded to `cp::Port::poll`. Wake-up packets posted through
    /// `Notify` are filtered out by `create_entry`.
    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);

        // Wake-up packets are identified by the address of the notify `Overlapped`.
        let notify_user_data = &self.notify.overlapped as *const Overlapped as usize;

        for e in self.notify.port.poll(timeout)? {
            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
                // SAFETY: NOTE(review) — assumes each delivered entry refers to a
                // live op; confirm with `Entry::notify`'s contract.
                unsafe { e.notify() }
            }
        }

        Ok(())
    }

    /// Returns a [`Waker`] that wakes the driver by posting the notify
    /// `Overlapped` to the completion port.
    pub fn waker(&self) -> Waker {
        Waker::from(self.notify.clone())
    }

    /// Creates a [`BufferPool`] with `BufferPool::new(buffer_len, buffer_size)`.
    /// Always returns `Ok`.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        Ok(BufferPool::new(buffer_len, buffer_size))
    }

    /// Releases a buffer pool. The pool is simply dropped, and this always
    /// returns `Ok`.
    ///
    /// # Safety
    ///
    /// NOTE(review): this implementation performs no unsafe operation. The
    /// `unsafe` presumably mirrors the signature of the other drivers — confirm.
    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
        Ok(())
    }
}
436
437impl AsRawFd for Driver {
438 fn as_raw_fd(&self) -> RawFd {
439 self.port().as_raw_handle() as _
440 }
441}
442
/// Wake-up channel for the driver.
struct Notify {
    /// The driver's completion port.
    port: cp::Port,
    /// Dedicated `Overlapped` posted to wake the driver. `Driver::poll` recognises
    /// wake-up packets by this field's address.
    overlapped: Overlapped,
}
448
449impl Notify {
450 fn new(port: cp::Port, overlapped: Overlapped) -> Self {
451 Self { port, overlapped }
452 }
453
454 pub fn notify(&self) -> io::Result<()> {
456 self.port.post_raw(&self.overlapped)
457 }
458}
459
460impl Wake for Notify {
461 fn wake(self: Arc<Self>) {
462 self.wake_by_ref();
463 }
464
465 fn wake_by_ref(self: &Arc<Self>) {
466 self.notify().ok();
467 }
468}
469
/// The driver's wrapper around a Win32 `OVERLAPPED`.
///
/// It is `#[repr(C)]` with `base` as the first field. Presumably this is what
/// lets the driver cast an `Overlapped` pointer to an `OVERLAPPED` pointer
/// (`optr.cast()`) when calling into [`OpCode`].
#[repr(C)]
pub struct Overlapped {
    /// The Win32 `OVERLAPPED` header. It must stay the first field.
    pub base: OVERLAPPED,
    /// Raw handle of the owning driver. In this file it is set to the completion
    /// port's raw handle.
    pub driver: RawFd,
}
478
479impl Overlapped {
480 pub(crate) fn new(driver: RawFd) -> Self {
481 Self {
482 base: unsafe { std::mem::zeroed() },
483 driver,
484 }
485 }
486}
487
// SAFETY: NOTE(review) — `Overlapped` is presumably not auto-`Send`/`Sync` only
// because `OVERLAPPED`/`RawFd` contain raw pointers or handles. These impls are
// required because `Notify`, which owns an `Overlapped`, is shared across threads
// through `Arc`: as a `Waker` (`Waker::from` needs `Send + Sync`) and inside
// blocking-pool closures. Soundness presumably rests on the `OVERLAPPED` being
// written only by the OS while an operation is pending — confirm.
unsafe impl Send for Overlapped {}
unsafe impl Sync for Overlapped {}