1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
//! Bindings to IOCP, I/O Completion Ports

use std::io;
use std::mem;
use std::os::windows::io::*;

use handle::Handle;
use winapi::*;
use kernel32::*;
use Overlapped;

/// A handle to a Windows I/O Completion Port.
#[derive(Debug)]
pub struct CompletionPort {
    // The port's underlying handle, wrapped in the crate's `Handle` type.
    handle: Handle,
}

/// A status message received from an I/O completion port.
///
/// These statuses can be created via the `new` or `zero` constructors and then
/// provided to a completion port, or they are read out of a completion port.
/// The fields of each status are read through its accessor methods.
#[derive(Clone, Copy, Debug)]
pub struct CompletionStatus(OVERLAPPED_ENTRY);

// `OVERLAPPED_ENTRY` carries a raw `lpOverlapped` pointer, which keeps the
// compiler from deriving `Send`/`Sync` on its own. The methods on this type in
// this file only store that pointer and hand it back; none dereference it.
unsafe impl Send for CompletionStatus {}
unsafe impl Sync for CompletionStatus {}

impl CompletionPort {
    /// Creates a new I/O completion port with the specified concurrency value.
    ///
    /// The number of threads given corresponds to the level of concurrency
    /// allowed for threads associated with this port. Consult the Windows
    /// documentation for more information about this value.
    pub fn new(threads: u32) -> io::Result<CompletionPort> {
        let ret = unsafe {
            CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0 as *mut _,
                                   0, threads)
        };
        if ret.is_null() {
            Err(io::Error::last_os_error())
        } else {
            Ok(CompletionPort { handle: Handle::new(ret) })
        }
    }

    /// Associates a new `HANDLE` to this I/O completion port.
    ///
    /// This function will associate the given handle to this port with the
    /// given `token` to be returned in status messages whenever it receives a
    /// notification.
    ///
    /// Any object which is convertible to a `HANDLE` via the `AsRawHandle`
    /// trait can be provided to this function, such as `std::fs::File` and
    /// friends.
    pub fn add_handle<T: AsRawHandle + ?Sized>(&self, token: usize,
                                               t: &T) -> io::Result<()> {
        self._add(token, t.as_raw_handle())
    }

    /// Associates a new `SOCKET` to this I/O completion port.
    ///
    /// This function will associate the given socket to this port with the
    /// given `token` to be returned in status messages whenever it receives a
    /// notification.
    ///
    /// Any object which is convertible to a `SOCKET` via the `AsRawSocket`
    /// trait can be provided to this function, such as `std::net::TcpStream`
    /// and friends.
    pub fn add_socket<T: AsRawSocket + ?Sized>(&self, token: usize,
                                               t: &T) -> io::Result<()> {
        self._add(token, t.as_raw_socket() as HANDLE)
    }

    fn _add(&self, token: usize, handle: HANDLE) -> io::Result<()> {
        let ret = unsafe {
            CreateIoCompletionPort(handle, self.handle.raw(),
                                   token as ULONG_PTR, 0)
        };
        if ret.is_null() {
            Err(io::Error::last_os_error())
        } else {
            debug_assert_eq!(ret, self.handle.raw());
            Ok(())
        }
    }

    /// Dequeue a completion status from this I/O completion port.
    ///
    /// This function will associate the calling thread with this completion
    /// port and then wait for a status message to become available. The precise
    /// semantics on when this function returns depends on the concurrency value
    /// specified when the port was created.
    ///
    /// A timeout (in milliseconds) can optionally be specified to this
    /// function. If `None` is provided this function will not time out, and
    /// otherwise it will time out after the specified duration has passed.
    ///
    /// On success this will return the status message which was dequeued from
    /// this completion port.
    pub fn get(&self, timeout_ms: Option<u32>) -> io::Result<CompletionStatus> {
        let mut bytes = 0;
        let mut token = 0;
        let mut overlapped = 0 as *mut _;
        let timeout = timeout_ms.unwrap_or(INFINITE);
        let ret = unsafe {
            GetQueuedCompletionStatus(self.handle.raw(),
                                      &mut bytes,
                                      &mut token,
                                      &mut overlapped,
                                      timeout)
        };
        ::cvt(ret).map(|_| {
            CompletionStatus(OVERLAPPED_ENTRY {
                dwNumberOfBytesTransferred: bytes,
                lpCompletionKey: token,
                lpOverlapped: overlapped,
                Internal: 0,
            })
        })
    }

    /// Dequeues a number of completion statuses from this I/O completion port.
    ///
    /// This function is the same as `get` except that it may return more than
    /// one status. A buffer of "zero" statuses is provided (the contents are
    /// not read) and then on success this function will return a sub-slice of
    /// statuses which represent those which were dequeued from this port. This
    /// function does not wait to fill up the entire list of statuses provided.
    ///
    /// Like with `get`, a timeout may be specified for this operation.
    pub fn get_many<'a>(&self,
                        list: &'a mut [CompletionStatus],
                        timeout_ms: Option<u32>)
                        -> io::Result<&'a mut [CompletionStatus]>
    {
        debug_assert_eq!(mem::size_of::<CompletionStatus>(),
                         mem::size_of::<OVERLAPPED_ENTRY>());
        let mut removed = 0;
        let timeout = timeout_ms.unwrap_or(INFINITE);
        let ret = unsafe {
            GetQueuedCompletionStatusEx(self.handle.raw(),
                                        list.as_ptr() as *mut _,
                                        list.len() as ULONG,
                                        &mut removed,
                                        timeout,
                                        FALSE)
        };
        match ::cvt(ret) {
            Ok(_) => Ok(&mut list[..removed as usize]),
            Err(e) => Err(e),
        }
    }

    /// Posts a new completion status onto this I/O completion port.
    ///
    /// This function will post the given status, with custom parameters, to the
    /// port. Threads blocked in `get` or `get_many` will eventually receive
    /// this status.
    pub fn post(&self, status: CompletionStatus) -> io::Result<()> {
        let ret = unsafe {
            PostQueuedCompletionStatus(self.handle.raw(),
                                       status.0.dwNumberOfBytesTransferred,
                                       status.0.lpCompletionKey,
                                       status.0.lpOverlapped)
        };
        ::cvt(ret).map(|_| ())
    }
}

impl AsRawHandle for CompletionPort {
    /// Returns the raw `HANDLE` of this completion port.
    fn as_raw_handle(&self) -> HANDLE {
        let port = &self.handle;
        port.raw()
    }
}

impl FromRawHandle for CompletionPort {
    /// Wraps an existing raw `HANDLE` in a `CompletionPort`.
    ///
    /// No validation is performed here; the value is passed straight to
    /// `Handle::new`.
    unsafe fn from_raw_handle(handle: HANDLE) -> CompletionPort {
        let wrapped = Handle::new(handle);
        CompletionPort { handle: wrapped }
    }
}

impl IntoRawHandle for CompletionPort {
    fn into_raw_handle(self) -> HANDLE {
        self.handle.into_raw()
    }
}

impl CompletionStatus {
    /// Builds a completion status from its three caller-visible parts.
    ///
    /// This is mainly useful for values handed to a port through the `post`
    /// method. The arguments are stored as given; the remaining `Internal`
    /// field of the underlying entry is set to zero.
    pub fn new(bytes: u32, token: usize, overlapped: *mut Overlapped)
               -> CompletionStatus {
        let entry = OVERLAPPED_ENTRY {
            lpCompletionKey: token as ULONG_PTR,
            lpOverlapped: overlapped as *mut _,
            Internal: 0,
            dwNumberOfBytesTransferred: bytes,
        };
        CompletionStatus(entry)
    }

    /// Builds a status whose byte count and token are zero and whose
    /// overlapped pointer is null.
    ///
    /// Handy for filling a stack buffer or vector of statuses that will be
    /// passed to the `get_many` function.
    pub fn zero() -> CompletionStatus {
        let no_overlapped: *mut Overlapped = ::std::ptr::null_mut();
        CompletionStatus::new(0, 0, no_overlapped)
    }

    /// Number of bytes transferred by the I/O operation this status
    /// describes.
    pub fn bytes_transferred(&self) -> u32 {
        let entry = &self.0;
        entry.dwNumberOfBytesTransferred
    }

    /// Completion key stored in this status, as a `usize`.
    ///
    /// For statuses read from a port this is the per-handle token that was
    /// supplied to `add_handle` or `add_socket`.
    pub fn token(&self) -> usize {
        let key = self.0.lpCompletionKey;
        key as usize
    }

    /// Pointer to the `Overlapped` structure recorded in this status, i.e.
    /// the one given when the I/O operation was started.
    pub fn overlapped(&self) -> *mut Overlapped {
        let raw = self.0.lpOverlapped;
        raw as *mut Overlapped
    }
}

#[cfg(test)]
mod tests {
    use std::mem;
    use winapi::*;

    use iocp::{CompletionPort, CompletionStatus};

    /// Compile-time check that the public types can cross and be shared
    /// between threads.
    #[test]
    fn is_send_sync() {
        fn is_send_sync<T: Send + Sync>() {}
        is_send_sync::<CompletionPort>();
        // `CompletionStatus` gets these traits from hand-written `unsafe impl`s
        // rather than auto-derivation, so pin them here as well.
        is_send_sync::<CompletionStatus>();
    }

    /// Tokens are `usize` in the public API but travel as `ULONG_PTR`; the two
    /// must have the same width for the casts to be lossless.
    #[test]
    fn token_right_size() {
        assert_eq!(mem::size_of::<usize>(), mem::size_of::<ULONG_PTR>());
    }

    /// Waiting on an empty port with a finite timeout reports `WAIT_TIMEOUT`.
    #[test]
    fn timeout() {
        let c = CompletionPort::new(1).unwrap();
        let err = c.get(Some(1)).unwrap_err();
        assert_eq!(err.raw_os_error(), Some(WAIT_TIMEOUT as i32));
    }

    /// A posted status comes back unchanged from `get`.
    #[test]
    fn get() {
        let c = CompletionPort::new(1).unwrap();
        c.post(CompletionStatus::new(1, 2, 3 as *mut _)).unwrap();
        let s = c.get(None).unwrap();
        assert_eq!(s.bytes_transferred(), 1);
        assert_eq!(s.token(), 2);
        assert_eq!(s.overlapped(), 3 as *mut _);
    }

    /// Two posted statuses are returned in order by `get_many`, and buffer
    /// entries past the returned sub-slice keep their zero value.
    #[test]
    fn get_many() {
        let c = CompletionPort::new(1).unwrap();

        c.post(CompletionStatus::new(1, 2, 3 as *mut _)).unwrap();
        c.post(CompletionStatus::new(4, 5, 6 as *mut _)).unwrap();

        let mut s = vec![CompletionStatus::zero(); 4];
        {
            let s = c.get_many(&mut s, None).unwrap();
            assert_eq!(s.len(), 2);
            assert_eq!(s[0].bytes_transferred(), 1);
            assert_eq!(s[0].token(), 2);
            assert_eq!(s[0].overlapped(), 3 as *mut _);
            assert_eq!(s[1].bytes_transferred(), 4);
            assert_eq!(s[1].token(), 5);
            assert_eq!(s[1].overlapped(), 6 as *mut _);
        }
        assert_eq!(s[2].bytes_transferred(), 0);
        assert_eq!(s[2].token(), 0);
        assert_eq!(s[2].overlapped(), 0 as *mut _);
    }
}