1use std::{
2 collections::HashMap,
3 io,
4 os::windows::io::{
5 AsHandle, AsRawHandle, AsRawSocket, AsSocket, BorrowedHandle, BorrowedSocket, OwnedHandle,
6 OwnedSocket,
7 },
8 pin::Pin,
9 sync::Arc,
10 task::Poll,
11 time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{Foundation::ERROR_CANCELLED, System::IO::OVERLAPPED};
16
17use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder};
18
19pub(crate) mod op;
20
21mod cp;
22mod wait;
23
24pub(crate) use windows_sys::Win32::Networking::WinSock::{
25 SOCKADDR_STORAGE as sockaddr_storage, socklen_t,
26};
27
/// Raw descriptor type used by this driver on Windows.
///
/// Both raw handles and raw sockets are cast to this integer type by the
/// [`AsRawFd`] implementations in this module.
pub type RawFd = isize;
32
/// Extracts the raw descriptor ([`RawFd`]) from an object.
pub trait AsRawFd {
    /// Returns the raw descriptor of `self` without transferring ownership.
    fn as_raw_fd(&self) -> RawFd;
}
38
/// An owned descriptor: either an owned handle or an owned socket.
#[derive(Debug)]
pub enum OwnedFd {
    /// An owned handle (files and child-process pipes convert into this variant).
    File(OwnedHandle),
    /// An owned socket.
    Socket(OwnedSocket),
}
47
48impl AsRawFd for OwnedFd {
49 fn as_raw_fd(&self) -> RawFd {
50 match self {
51 Self::File(fd) => fd.as_raw_handle() as _,
52 Self::Socket(s) => s.as_raw_socket() as _,
53 }
54 }
55}
56
57impl AsRawFd for RawFd {
58 fn as_raw_fd(&self) -> RawFd {
59 *self
60 }
61}
62
63impl AsRawFd for std::fs::File {
64 fn as_raw_fd(&self) -> RawFd {
65 self.as_raw_handle() as _
66 }
67}
68
69impl AsRawFd for OwnedHandle {
70 fn as_raw_fd(&self) -> RawFd {
71 self.as_raw_handle() as _
72 }
73}
74
75impl AsRawFd for socket2::Socket {
76 fn as_raw_fd(&self) -> RawFd {
77 self.as_raw_socket() as _
78 }
79}
80
81impl AsRawFd for OwnedSocket {
82 fn as_raw_fd(&self) -> RawFd {
83 self.as_raw_socket() as _
84 }
85}
86
87impl AsRawFd for std::process::ChildStdin {
88 fn as_raw_fd(&self) -> RawFd {
89 self.as_raw_handle() as _
90 }
91}
92
93impl AsRawFd for std::process::ChildStdout {
94 fn as_raw_fd(&self) -> RawFd {
95 self.as_raw_handle() as _
96 }
97}
98
99impl AsRawFd for std::process::ChildStderr {
100 fn as_raw_fd(&self) -> RawFd {
101 self.as_raw_handle() as _
102 }
103}
104
105impl From<OwnedHandle> for OwnedFd {
106 fn from(value: OwnedHandle) -> Self {
107 Self::File(value)
108 }
109}
110
111impl From<std::fs::File> for OwnedFd {
112 fn from(value: std::fs::File) -> Self {
113 Self::File(OwnedHandle::from(value))
114 }
115}
116
117impl From<std::process::ChildStdin> for OwnedFd {
118 fn from(value: std::process::ChildStdin) -> Self {
119 Self::File(OwnedHandle::from(value))
120 }
121}
122
123impl From<std::process::ChildStdout> for OwnedFd {
124 fn from(value: std::process::ChildStdout) -> Self {
125 Self::File(OwnedHandle::from(value))
126 }
127}
128
129impl From<std::process::ChildStderr> for OwnedFd {
130 fn from(value: std::process::ChildStderr) -> Self {
131 Self::File(OwnedHandle::from(value))
132 }
133}
134
135impl From<OwnedSocket> for OwnedFd {
136 fn from(value: OwnedSocket) -> Self {
137 Self::Socket(value)
138 }
139}
140
141impl From<socket2::Socket> for OwnedFd {
142 fn from(value: socket2::Socket) -> Self {
143 Self::Socket(OwnedSocket::from(value))
144 }
145}
146
/// A borrowed descriptor: either a borrowed handle or a borrowed socket,
/// valid for the lifetime `'a`.
#[derive(Debug)]
pub enum BorrowedFd<'a> {
    /// A borrowed handle.
    File(BorrowedHandle<'a>),
    /// A borrowed socket.
    Socket(BorrowedSocket<'a>),
}
155
156impl AsRawFd for BorrowedFd<'_> {
157 fn as_raw_fd(&self) -> RawFd {
158 match self {
159 Self::File(fd) => fd.as_raw_handle() as RawFd,
160 Self::Socket(s) => s.as_raw_socket() as RawFd,
161 }
162 }
163}
164
165impl<'a> From<BorrowedHandle<'a>> for BorrowedFd<'a> {
166 fn from(value: BorrowedHandle<'a>) -> Self {
167 Self::File(value)
168 }
169}
170
171impl<'a> From<BorrowedSocket<'a>> for BorrowedFd<'a> {
172 fn from(value: BorrowedSocket<'a>) -> Self {
173 Self::Socket(value)
174 }
175}
176
/// Borrows the descriptor of an object as a [`BorrowedFd`].
pub trait AsFd {
    /// Returns a borrowed descriptor whose lifetime is tied to `self`.
    fn as_fd(&self) -> BorrowedFd<'_>;
}
182
183impl AsFd for OwnedFd {
184 fn as_fd(&self) -> BorrowedFd<'_> {
185 match self {
186 Self::File(fd) => fd.as_fd(),
187 Self::Socket(s) => s.as_fd(),
188 }
189 }
190}
191
192impl AsFd for std::fs::File {
193 fn as_fd(&self) -> BorrowedFd<'_> {
194 self.as_handle().into()
195 }
196}
197
198impl AsFd for OwnedHandle {
199 fn as_fd(&self) -> BorrowedFd<'_> {
200 self.as_handle().into()
201 }
202}
203
204impl AsFd for socket2::Socket {
205 fn as_fd(&self) -> BorrowedFd<'_> {
206 self.as_socket().into()
207 }
208}
209
210impl AsFd for OwnedSocket {
211 fn as_fd(&self) -> BorrowedFd<'_> {
212 self.as_socket().into()
213 }
214}
215
216impl AsFd for std::process::ChildStdin {
217 fn as_fd(&self) -> BorrowedFd<'_> {
218 self.as_handle().into()
219 }
220}
221
222impl AsFd for std::process::ChildStdout {
223 fn as_fd(&self) -> BorrowedFd<'_> {
224 self.as_handle().into()
225 }
226}
227
228impl AsFd for std::process::ChildStderr {
229 fn as_fd(&self) -> BorrowedFd<'_> {
230 self.as_handle().into()
231 }
232}
233
/// How the driver should run an operation; returned by [`OpCode::op_type`].
pub enum OpType {
    /// The driver calls [`OpCode::operate`] directly with the operation's
    /// overlapped pointer when the operation is pushed.
    Overlapped,
    /// The driver dispatches the operation to its thread pool and posts the
    /// result to the completion port from there.
    Blocking,
    /// The driver registers a wait on the given descriptor (presumably a
    /// waitable Win32 object such as an event — TODO confirm against `wait::Wait`).
    /// When the wait's completion entry arrives without error and the wait was
    /// not cancelled, the operation is executed on the polling thread.
    Event(RawFd),
}
246
/// An operation that can be submitted to the IOCP driver.
pub trait OpCode {
    /// Tells the driver how to run this operation.
    ///
    /// The default is [`OpType::Overlapped`].
    fn op_type(&self) -> OpType {
        OpType::Overlapped
    }

    /// Performs the operation, given the pointer to its overlapped structure.
    ///
    /// Returns `Poll::Pending` when completion will be reported later, or
    /// `Poll::Ready` with the result when it finished immediately.
    ///
    /// # Safety
    ///
    /// NOTE(review): the precise contract is not stated in this file. Presumably
    /// `optr` must point to the overlapped header associated with this operation,
    /// and `self` must stay alive and pinned until the operation completes —
    /// confirm against the `Key` implementation.
    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;

    /// Requests cancellation of the operation.
    ///
    /// The default implementation ignores `optr` and returns `Ok(())`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the same requirements on `optr` as
    /// [`OpCode::operate`] apply — confirm.
    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
        // The pointer is unused by the default implementation.
        let _optr = optr;
        Ok(())
    }
}
282
/// Low-level driver built on top of a completion port.
pub(crate) struct Driver {
    /// The completion port that operations are attached to and polled from.
    port: cp::Port,
    /// Pending waits for `OpType::Event` operations, keyed by the operation's user data.
    waits: HashMap<usize, wait::Wait>,
    /// Thread pool used to run `OpType::Blocking` operations.
    pool: AsyncifyPool,
    /// Packet posted by `NotifyHandle::notify`; entries whose user data equals
    /// its address are discarded while polling.
    notify_overlapped: Arc<Overlapped>,
}
290
291impl Driver {
292 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
293 instrument!(compio_log::Level::TRACE, "new", ?builder);
294
295 let port = cp::Port::new()?;
296 let driver = port.as_raw_handle() as _;
297 Ok(Self {
298 port,
299 waits: HashMap::default(),
300 pool: builder.create_or_get_thread_pool(),
301 notify_overlapped: Arc::new(Overlapped::new(driver)),
302 })
303 }
304
305 pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
306 Key::new(self.port.as_raw_handle() as _, op)
307 }
308
309 pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
310 self.port.attach(fd)
311 }
312
313 pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
314 instrument!(compio_log::Level::TRACE, "cancel", ?op);
315 trace!("cancel RawOp");
316 let overlapped_ptr = op.as_mut_ptr();
317 if let Some(w) = self.waits.get_mut(&op.user_data()) {
318 if w.cancel().is_ok() {
319 self.port.post_raw(overlapped_ptr).ok();
322 }
323 }
324 let op = op.as_op_pin();
325 trace!("call OpCode::cancel");
327 unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
328 }
329
330 pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
331 instrument!(compio_log::Level::TRACE, "push", ?op);
332 let user_data = op.user_data();
333 trace!("push RawOp");
334 let optr = op.as_mut_ptr();
335 let op_pin = op.as_op_pin();
336 match op_pin.op_type() {
337 OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
338 OpType::Blocking => loop {
339 if self.push_blocking(user_data) {
340 break Poll::Pending;
341 } else {
342 unsafe {
345 self.poll(None)?;
346 }
347 }
348 },
349 OpType::Event(e) => {
350 self.waits
351 .insert(user_data, wait::Wait::new(&self.port, e, op)?);
352 Poll::Pending
353 }
354 }
355 }
356
357 fn push_blocking(&mut self, user_data: usize) -> bool {
358 let port = self.port.handle();
359 self.pool
360 .dispatch(move || {
361 let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
362 let optr = op.as_mut_ptr();
363 let res = op.operate_blocking();
364 port.post(res, optr).ok();
365 })
366 .is_ok()
367 }
368
369 fn create_entry(
370 notify_user_data: usize,
371 waits: &mut HashMap<usize, wait::Wait>,
372 entry: Entry,
373 ) -> Option<Entry> {
374 let user_data = entry.user_data();
375 if user_data != notify_user_data {
376 if let Some(w) = waits.remove(&user_data) {
377 if w.is_cancelled() {
378 Some(Entry::new(
379 user_data,
380 Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
381 ))
382 } else if entry.result.is_err() {
383 Some(entry)
384 } else {
385 let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
386 let result = op.operate_blocking();
387 Some(Entry::new(user_data, result))
388 }
389 } else {
390 Some(entry)
391 }
392 } else {
393 None
394 }
395 }
396
397 pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
398 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
399
400 let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
401
402 for e in self.port.poll(timeout)? {
403 if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
404 e.notify();
405 }
406 }
407
408 Ok(())
409 }
410
411 pub fn handle(&self) -> NotifyHandle {
412 NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
413 }
414
415 pub fn create_buffer_pool(
416 &mut self,
417 buffer_len: u16,
418 buffer_size: usize,
419 ) -> io::Result<BufferPool> {
420 Ok(BufferPool::new(buffer_len, buffer_size))
421 }
422
423 pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
427 Ok(())
428 }
429}
430
431impl AsRawFd for Driver {
432 fn as_raw_fd(&self) -> RawFd {
433 self.port.as_raw_handle() as _
434 }
435}
436
/// A handle that posts a notification packet to the driver's completion port.
pub struct NotifyHandle {
    /// Handle to the completion port that receives the notification.
    port: cp::PortHandle,
    /// The notification packet shared with the driver.
    overlapped: Arc<Overlapped>,
}
442
443impl NotifyHandle {
444 fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
445 Self { port, overlapped }
446 }
447
448 pub fn notify(&self) -> io::Result<()> {
450 self.port.post_raw(self.overlapped.as_ref())
451 }
452}
453
/// The overlapped structure used by the driver.
///
/// `#[repr(C)]` keeps `base` as the first field in memory, so a pointer to
/// this struct and a pointer to its `OVERLAPPED` header share the same address.
#[repr(C)]
pub struct Overlapped {
    /// The Win32 `OVERLAPPED` header.
    pub base: OVERLAPPED,
    /// The raw handle of the driver (its completion port) this packet belongs to.
    pub driver: RawFd,
}
462
463impl Overlapped {
464 pub(crate) fn new(driver: RawFd) -> Self {
465 Self {
466 base: unsafe { std::mem::zeroed() },
467 driver,
468 }
469 }
470}
471
// SAFETY: NOTE(review) `Overlapped` is shared through an `Arc` between the
// driver and its `NotifyHandle`s, which is presumably why these impls are
// needed (the `OVERLAPPED` header likely contains raw pointers that suppress
// the auto traits). Nothing in this file mutates it after construction —
// confirm that no other module does.
unsafe impl Send for Overlapped {}
unsafe impl Sync for Overlapped {}