1use std::{
2 collections::HashMap,
3 io,
4 os::windows::io::{
5 AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6 OwnedSocket,
7 },
8 pin::Pin,
9 sync::Arc,
10 task::Poll,
11 time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{
16 Foundation::{ERROR_CANCELLED, HANDLE},
17 System::IO::OVERLAPPED,
18};
19
20use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder};
21
22pub(crate) mod op;
23
24mod cp;
25mod wait;
26
/// Raw OS handle type used throughout this driver.
///
/// An alias of the Win32 `HANDLE`. File handles, sockets (cast from their raw
/// socket value) and the completion port itself are all represented with it.
pub type RawFd = HANDLE;
31
/// Extracts the raw OS handle (a [`RawFd`]) from an object.
pub trait AsRawFd {
    /// Returns the raw handle, borrowing `self`.
    ///
    /// Sockets are returned as their raw socket value cast to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd;
}
37
/// An owned Windows resource: either a generic handle or a socket.
#[derive(Debug)]
pub enum OwnedFd {
    /// An owned handle, e.g. a file or a child process's stdio pipe.
    File(OwnedHandle),
    /// An owned socket.
    Socket(OwnedSocket),
}
46
47impl AsRawFd for OwnedFd {
48 fn as_raw_fd(&self) -> RawFd {
49 match self {
50 Self::File(fd) => fd.as_raw_handle() as _,
51 Self::Socket(s) => s.as_raw_socket() as _,
52 }
53 }
54}
55
56impl AsRawFd for RawFd {
57 fn as_raw_fd(&self) -> RawFd {
58 *self
59 }
60}
61
62impl AsRawFd for std::fs::File {
63 fn as_raw_fd(&self) -> RawFd {
64 self.as_raw_handle() as _
65 }
66}
67
68impl AsRawFd for OwnedHandle {
69 fn as_raw_fd(&self) -> RawFd {
70 self.as_raw_handle() as _
71 }
72}
73
74impl AsRawFd for socket2::Socket {
75 fn as_raw_fd(&self) -> RawFd {
76 self.as_raw_socket() as _
77 }
78}
79
80impl AsRawFd for OwnedSocket {
81 fn as_raw_fd(&self) -> RawFd {
82 self.as_raw_socket() as _
83 }
84}
85
86impl AsRawFd for std::process::ChildStdin {
87 fn as_raw_fd(&self) -> RawFd {
88 self.as_raw_handle() as _
89 }
90}
91
92impl AsRawFd for std::process::ChildStdout {
93 fn as_raw_fd(&self) -> RawFd {
94 self.as_raw_handle() as _
95 }
96}
97
98impl AsRawFd for std::process::ChildStderr {
99 fn as_raw_fd(&self) -> RawFd {
100 self.as_raw_handle() as _
101 }
102}
103
104impl From<OwnedHandle> for OwnedFd {
105 fn from(value: OwnedHandle) -> Self {
106 Self::File(value)
107 }
108}
109
110impl From<std::fs::File> for OwnedFd {
111 fn from(value: std::fs::File) -> Self {
112 Self::File(OwnedHandle::from(value))
113 }
114}
115
116impl From<std::process::ChildStdin> for OwnedFd {
117 fn from(value: std::process::ChildStdin) -> Self {
118 Self::File(OwnedHandle::from(value))
119 }
120}
121
122impl From<std::process::ChildStdout> for OwnedFd {
123 fn from(value: std::process::ChildStdout) -> Self {
124 Self::File(OwnedHandle::from(value))
125 }
126}
127
128impl From<std::process::ChildStderr> for OwnedFd {
129 fn from(value: std::process::ChildStderr) -> Self {
130 Self::File(OwnedHandle::from(value))
131 }
132}
133
134impl From<OwnedSocket> for OwnedFd {
135 fn from(value: OwnedSocket) -> Self {
136 Self::Socket(value)
137 }
138}
139
140impl From<socket2::Socket> for OwnedFd {
141 fn from(value: socket2::Socket) -> Self {
142 Self::Socket(OwnedSocket::from(value))
143 }
144}
145
/// A borrowed Windows resource: either a generic handle or a socket.
///
/// The borrow lasts for `'a`, mirroring [`BorrowedHandle`] and [`BorrowedSocket`].
#[derive(Debug)]
pub enum BorrowedFd<'a> {
    /// A borrowed handle, e.g. a file or a child process's stdio pipe.
    File(BorrowedHandle<'a>),
    /// A borrowed socket.
    Socket(BorrowedSocket<'a>),
}
154
155impl AsRawFd for BorrowedFd<'_> {
156 fn as_raw_fd(&self) -> RawFd {
157 match self {
158 Self::File(fd) => fd.as_raw_handle() as RawFd,
159 Self::Socket(s) => s.as_raw_socket() as RawFd,
160 }
161 }
162}
163
164impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
165 fn from(value: BorrowedHandle<'a>) -> Self {
166 Self::File(value)
167 }
168}
169
170impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
171 fn from(value: BorrowedSocket<'a>) -> Self {
172 Self::Socket(value)
173 }
174}
175
/// Borrows the underlying handle or socket of an object as a [`BorrowedFd`].
pub trait AsFd {
    /// Returns a [`BorrowedFd`] whose lifetime is tied to `&self`.
    fn as_fd(&self) -> BorrowedFd<'_>;
}
181
182impl AsFd for OwnedFd {
183 fn as_fd(&self) -> BorrowedFd<'_> {
184 match self {
185 Self::File(fd) => fd.as_fd(),
186 Self::Socket(s) => s.as_fd(),
187 }
188 }
189}
190
191impl AsFd for std::fs::File {
192 fn as_fd(&self) -> BorrowedFd<'_> {
193 self.as_handle().into()
194 }
195}
196
197impl AsFd for OwnedHandle {
198 fn as_fd(&self) -> BorrowedFd<'_> {
199 self.as_handle().into()
200 }
201}
202
203impl AsFd for socket2::Socket {
204 fn as_fd(&self) -> BorrowedFd<'_> {
205 self.as_socket().into()
206 }
207}
208
209impl AsFd for OwnedSocket {
210 fn as_fd(&self) -> BorrowedFd<'_> {
211 self.as_socket().into()
212 }
213}
214
215impl AsFd for std::process::ChildStdin {
216 fn as_fd(&self) -> BorrowedFd<'_> {
217 self.as_handle().into()
218 }
219}
220
221impl AsFd for std::process::ChildStdout {
222 fn as_fd(&self) -> BorrowedFd<'_> {
223 self.as_handle().into()
224 }
225}
226
227impl AsFd for std::process::ChildStderr {
228 fn as_fd(&self) -> BorrowedFd<'_> {
229 self.as_handle().into()
230 }
231}
232
233pub enum OpType {
235 Overlapped,
237 Blocking,
240 Event(RawFd),
244}
245
246pub trait OpCode {
248 fn op_type(&self) -> OpType {
251 OpType::Overlapped
252 }
253
254 unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
268
269 unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
277 let _optr = optr; Ok(())
279 }
280}
281
/// The IOCP-based proactor driver.
pub(crate) struct Driver {
    /// Completion port that operations are attached to and complete on.
    port: cp::Port,
    /// Pending `OpType::Event` waits, keyed by the operation's user data.
    waits: HashMap<usize, wait::Wait>,
    /// Thread pool that runs `OpType::Blocking` operations.
    pool: AsyncifyPool,
    /// Marker posted by `NotifyHandle::notify` to wake `poll`. Its address
    /// identifies wake-up packets, which `poll` discards.
    notify_overlapped: Arc<Overlapped>,
}
289
impl Driver {
    /// Creates a new IOCP driver.
    ///
    /// Opens a completion port, takes the thread pool from
    /// `builder.create_or_get_thread_pool()`, and allocates the notification
    /// marker shared with every [`NotifyHandle`].
    ///
    /// # Errors
    ///
    /// Returns the error from creating the completion port.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);

        let port = cp::Port::new()?;
        // The marker records the raw completion-port handle as its `driver`.
        let driver = port.as_raw_handle() as _;
        Ok(Self {
            port,
            waits: HashMap::default(),
            pool: builder.create_or_get_thread_pool(),
            notify_overlapped: Arc::new(Overlapped::new(driver)),
        })
    }

    /// Reports that this is the IOCP backend.
    pub fn driver_type(&self) -> DriverType {
        DriverType::IOCP
    }

    /// Wraps `op` in a [`Key`] tagged with this driver's completion-port handle.
    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.port.as_raw_handle() as _, op)
    }

    /// Attaches `fd` to the completion port (delegates to `cp::Port::attach`).
    ///
    /// # Errors
    ///
    /// Propagates the error from `cp::Port::attach`.
    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
        self.port.attach(fd)
    }

    /// Requests cancellation of `op`. This is best effort; every failure is ignored.
    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        // Take the raw pointer before `as_op_pin` borrows `op`.
        let overlapped_ptr = op.as_mut_ptr();
        if let Some(w) = self.waits.get_mut(&op.user_data()) {
            if w.cancel().is_ok() {
                // The wait was cancelled, so post the operation's own overlapped
                // to make its packet reach `poll`. `create_entry` then sees
                // `is_cancelled()` and reports `ERROR_CANCELLED`, so no result
                // is set here. NOTE(review): relies on the posted packet carrying
                // the same user data that keys `waits`; confirm.
                self.port.post_raw(overlapped_ptr).ok();
            }
        }
        let op = op.as_op_pin();
        // Failing to cancel is acceptable; the result is deliberately ignored.
        trace!("call OpCode::cancel");
        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
    }

    /// Submits `op` according to its [`OpType`].
    ///
    /// - `Overlapped`: calls `OpCode::operate` and returns its result.
    /// - `Blocking`: dispatches it to the thread pool and returns `Pending`.
    /// - `Event(e)`: registers a wait on `e` and returns `Pending`.
    ///
    /// # Errors
    ///
    /// Returns `Ready(Err(_))` if the operation fails immediately, if polling
    /// fails while waiting for pool capacity, or if the wait cannot be created.
    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        trace!("push RawOp");
        // Take the raw pointer before `as_op_pin` borrows `op`.
        let optr = op.as_mut_ptr();
        let op_pin = op.as_op_pin();
        match op_pin.op_type() {
            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
            OpType::Blocking => loop {
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    // Dispatch was refused (presumably no free pool capacity).
                    // Waiting without a timeout is acceptable because every
                    // dispatched blocking task posts to the port when it
                    // finishes (see `push_blocking`). Retry afterwards.
                    unsafe {
                        self.poll(None)?;
                    }
                }
            },
            OpType::Event(e) => {
                self.waits
                    .insert(user_data, wait::Wait::new(&self.port, e, op)?);
                Poll::Pending
            }
        }
    }

    /// Dispatches the operation identified by `user_data` to the thread pool.
    ///
    /// The pool task rebuilds the key, runs `operate_blocking`, and posts the
    /// result to the completion port (posting errors are ignored). Returns
    /// whether the pool accepted the task.
    fn push_blocking(&mut self, user_data: usize) -> bool {
        let port = self.port.handle();
        self.pool
            .dispatch(move || {
                // NOTE(review): assumes the operation behind `user_data` stays
                // alive until its completion is consumed; confirm with `Key`.
                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
                let optr = op.as_mut_ptr();
                let res = op.operate_blocking();
                port.post(res, optr).ok();
            })
            .is_ok()
    }

    /// Turns a dequeued packet into the entry to notify, if any.
    ///
    /// - The notification marker (`notify_user_data`) yields `None`.
    /// - A packet for a registered event wait removes the wait. It reports
    ///   `ERROR_CANCELLED` if the wait was cancelled and passes an error packet
    ///   through unchanged. Otherwise it runs `operate_blocking` now and reports
    ///   that result.
    /// - Any other packet is returned unchanged.
    fn create_entry(
        notify_user_data: usize,
        waits: &mut HashMap<usize, wait::Wait>,
        entry: Entry,
    ) -> Option<Entry> {
        let user_data = entry.user_data();
        if user_data != notify_user_data {
            if let Some(w) = waits.remove(&user_data) {
                if w.is_cancelled() {
                    Some(Entry::new(
                        user_data,
                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
                    ))
                } else if entry.result.is_err() {
                    Some(entry)
                } else {
                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
                    let result = op.operate_blocking();
                    Some(Entry::new(user_data, result))
                }
            } else {
                Some(entry)
            }
        } else {
            None
        }
    }

    /// Dequeues completion packets from the port, waiting according to
    /// `timeout`, and notifies each resulting entry.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out here. This method turns
    /// packet user data back into operations (`Key::new_unchecked` in
    /// `create_entry`), so presumably every operation that can complete on
    /// this port must still be alive. Confirm.
    ///
    /// # Errors
    ///
    /// Propagates the error from polling the completion port.
    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);

        // Packets posted by `NotifyHandle::notify` carry the marker's address.
        let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;

        for e in self.port.poll(timeout)? {
            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
                e.notify();
            }
        }

        Ok(())
    }

    /// Returns a [`NotifyHandle`] that shares the port and the notification
    /// marker. Its packets are discarded by `poll`.
    pub fn handle(&self) -> NotifyHandle {
        NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
    }

    /// Creates a [`BufferPool`] with the given parameters. It never fails on
    /// this backend.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        Ok(BufferPool::new(buffer_len, buffer_size))
    }

    /// Releases a buffer pool. On this backend it only drops the pool.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body is a no-op, so the contract is not visible here.
    /// Presumably it is shared with other driver backends; confirm.
    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
        Ok(())
    }
}
433
434impl AsRawFd for Driver {
435 fn as_raw_fd(&self) -> RawFd {
436 self.port.as_raw_handle() as _
437 }
438}
439
/// Wakes a driver blocked in `poll` by posting a marker packet to its
/// completion port.
pub struct NotifyHandle {
    /// Handle to the driver's completion port.
    port: cp::PortHandle,
    /// The driver's notification marker. Its address identifies wake-up
    /// packets, which the driver discards.
    overlapped: Arc<Overlapped>,
}
445
446impl NotifyHandle {
447 fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
448 Self { port, overlapped }
449 }
450
451 pub fn notify(&self) -> io::Result<()> {
453 self.port.post_raw(self.overlapped.as_ref())
454 }
455}
456
/// An [`OVERLAPPED`] extended with the raw handle of the driver it belongs to.
///
/// `#[repr(C)]` with `base` as the first field means a pointer to an
/// `Overlapped` is also a valid pointer to its `OVERLAPPED`.
#[repr(C)]
pub struct Overlapped {
    /// The Win32 overlapped structure handed to the OS. It must stay the first field.
    pub base: OVERLAPPED,
    /// Raw handle of the owning driver (`Driver::new` passes its completion port).
    pub driver: RawFd,
}
465
466impl Overlapped {
467 pub(crate) fn new(driver: RawFd) -> Self {
468 Self {
469 base: unsafe { std::mem::zeroed() },
470 driver,
471 }
472 }
473}
474
// SAFETY: `OVERLAPPED` and `RawFd` (`HANDLE`) may contain raw pointers, so
// these traits are not implemented automatically. NOTE(review): soundness
// relies on a shared `Overlapped` (e.g. the notification marker that `Driver`
// and `NotifyHandle` hold in an `Arc`) not being mutated from Rust after
// construction, with only its address posted to the completion port. Confirm
// this holds for every use.
unsafe impl Send for Overlapped {}
// SAFETY: as above, shared references are only used to take the address.
unsafe impl Sync for Overlapped {}