lazy_socket/raw/
windows.rs

1use std::io;
2use std::os::raw::*;
3use std::net;
4use std::mem;
5use std::ptr;
6use std::sync::{Once, ONCE_INIT};
7
8#[cfg(feature = "safe_buffer_len")]
9use std::cmp;
10
11mod winapi {
12    #![allow(bad_style)]
13    #![allow(dead_code)]
14
15    extern crate winapi;
16
17    pub type STD_SOCKET = ::std::os::windows::io::RawSocket;
18
19	pub use self::winapi::shared::minwindef::{
20		DWORD,
21		WORD,
22		USHORT
23	};
24
25	pub use self::winapi::shared::ntdef::{
26		CHAR,
27		HANDLE,
28    };
29
30    pub use self::winapi::shared::winerror::{
31        WSAESHUTDOWN,
32        WSAEINVAL
33    };
34
35    pub use self::winapi::shared::ws2def::{
36		ADDRESS_FAMILY,
37
38        AF_UNSPEC,
39        AF_INET,
40        AF_INET6,
41        AF_IRDA,
42        AF_BTH,
43
44        SOCK_STREAM,
45        SOCK_DGRAM,
46        SOCK_RAW,
47        SOCK_RDM,
48        SOCK_SEQPACKET,
49
50        IPPROTO_NONE,
51        IPPROTO_ICMP,
52        IPPROTO_TCP,
53        IPPROTO_UDP,
54        IPPROTO_ICMPV6,
55
56        SOCKADDR_STORAGE_LH,
57        SOCKADDR_IN,
58        SOCKADDR,
59    };
60
61    pub const SOCK_NONBLOCK: winapi::ctypes::c_int = 0o0004000;
62    pub const SOCK_CLOEXEC: winapi::ctypes::c_int = 0o2000000;
63
64    pub use self::winapi::shared::ws2ipdef::SOCKADDR_IN6_LH;
65
66    pub use self::winapi::shared::inaddr::{
67        in_addr,
68    };
69    pub use self::winapi::shared::in6addr::{
70        in6_addr,
71    };
72
73    pub use self::winapi::um::winsock2::{
74        SOCKET,
75		GROUP,
76
77        INVALID_SOCKET,
78        SOCKET_ERROR,
79        FIONBIO,
80
81        FD_SETSIZE,
82        WSADESCRIPTION_LEN,
83        WSASYS_STATUS_LEN,
84
85        WSADATA,
86        fd_set,
87        timeval,
88        LPWSADATA,
89
90        WSAStartup,
91        WSACleanup,
92
93        getsockname,
94        socket,
95        bind,
96        listen,
97        accept,
98        connect,
99        recv,
100        recvfrom,
101        send,
102        sendto,
103        getsockopt,
104        setsockopt,
105        ioctlsocket,
106        shutdown,
107        closesocket,
108        select
109    };
110
111    // Currently not available in `winapi`.
112    pub const HANDLE_FLAG_INHERIT: DWORD = 1;
113
114    pub use self::winapi::um::handleapi::{
115    	SetHandleInformation,
116    	GetHandleInformation
117    };
118}
119
/// Implements `Into<c_int>` for each listed field-less enum by casting its discriminant.
///
/// Usage: `impl_into_trait!(TypeA, TypeB);`
macro_rules! impl_into_trait {
    ($($target:ty),+) => {
        $(
            impl Into<c_int> for $target {
                fn into(self) -> c_int {
                    // Field-less enums cast straight to their discriminant value.
                    self as c_int
                }
            }
        )+
    };
}
131
132#[allow(non_snake_case, non_upper_case_globals)]
133///Socket family
134pub mod Family {
135    use super::{c_int, winapi};
136
137    pub const UNSPECIFIED: c_int = winapi::AF_UNSPEC;
138
139    pub const IPv4: c_int = winapi::AF_INET;
140    pub const IPv6: c_int = winapi::AF_INET6;
141    pub const IRDA: c_int = winapi::AF_IRDA;
142    pub const BTH:  c_int = winapi::AF_BTH;
143}
144
145#[allow(non_snake_case)]
146///Socket type
147pub mod Type {
148    use super::{c_int, winapi};
149
150    pub const STREAM:    c_int = winapi::SOCK_STREAM;
151    pub const DATAGRAM:  c_int = winapi::SOCK_DGRAM;
152    pub const RAW:       c_int = winapi::SOCK_RAW;
153    pub const RDM:       c_int = winapi::SOCK_RDM;
154    pub const SEQPACKET: c_int = winapi::SOCK_SEQPACKET;
155}
156
157#[allow(non_snake_case, non_upper_case_globals)]
158///Socket protocol
159pub mod Protocol {
160    use super::{c_int, winapi};
161
162    pub const NONE:   c_int = winapi::IPPROTO_NONE as i32;
163    pub const ICMPv4: c_int = winapi::IPPROTO_ICMP as i32;
164    pub const TCP:    c_int = winapi::IPPROTO_TCP as i32;
165    pub const UDP:    c_int = winapi::IPPROTO_UDP as i32;
166    pub const ICMPv6: c_int = winapi::IPPROTO_ICMPV6 as i32;
167}
168
169#[allow(non_snake_case)]
170///Possible flags for `accept4()`
171///
172///Note that these flags correspond to emulated constants that are not represented
173///in the OS in this way.
174bitflags! (pub flags AcceptFlags: c_int {
175    const NON_BLOCKING    = winapi::SOCK_NONBLOCK,
176    const NON_INHERITABLE = winapi::SOCK_CLOEXEC,
177});
178
///Type of socket's shutdown operation.
///
///The discriminants are the values of the Winsock `how` argument of `shutdown()`
///(`SD_RECEIVE`, `SD_SEND`, `SD_BOTH`).
#[repr(i32)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ShutdownType {
    ///Stops any further receives.
    Receive = 0,
    ///Stops any further sends.
    Send = 1,
    ///Stops both sends and receives.
    Both = 2
}
190
191impl_into_trait!(ShutdownType);
192
193///Raw socket
194pub struct Socket {
195    inner: winapi::SOCKET
196}
197
198impl Socket {
199    ///Initializes new socket.
200    ///
201    ///Corresponds to C connect()
202    pub fn new(family: c_int, _type: c_int, protocol: c_int) -> io::Result<Socket> {
203        static INIT: Once = ONCE_INIT;
204
205        INIT.call_once(|| {
206            //just to initialize winsock inside libstd
207            let _ = net::UdpSocket::bind("127.0.0.1:34254");
208        });
209
210        unsafe {
211            match winapi::socket(family, _type, protocol) {
212                winapi::INVALID_SOCKET => Err(io::Error::last_os_error()),
213                fd => Ok(Socket {
214                    inner: fd
215                }),
216            }
217        }
218    }
219
220    ///Returns underlying socket descriptor.
221    ///
222    ///Note: ownership is not transferred.
223    pub fn raw(&self) -> winapi::SOCKET {
224        self.inner
225    }
226
227    ///Retrieves socket name i.e. address
228    ///
229    ///Wraps `getsockname()`
230    ///
231    ///Available for binded/connected sockets.
232    pub fn name(&self) -> io::Result<net::SocketAddr> {
233        unsafe {
234            let mut storage: winapi::SOCKADDR_STORAGE_LH = mem::zeroed();
235            let mut len = mem::size_of_val(&storage) as c_int;
236
237            match winapi::getsockname(self.inner, &mut storage as *mut _ as *mut _, &mut len) {
238                winapi::SOCKET_ERROR => Err(io::Error::last_os_error()),
239                _ => sockaddr_to_addr(&storage, len)
240            }
241        }
242    }
243
244    ///Binds socket to address.
245    pub fn bind(&self, addr: &net::SocketAddr) -> io::Result<()> {
246        let (addr, len) = get_raw_addr(addr);
247
248        unsafe {
249            match winapi::bind(self.inner, addr, len) {
250                0 => Ok(()),
251                _ => Err(io::Error::last_os_error())
252            }
253        }
254    }
255
256    ///Listens for incoming connections on this socket.
257    pub fn listen(&self, backlog: c_int) -> io::Result<()> {
258        unsafe {
259            match winapi::listen(self.inner, backlog) {
260                0 => Ok(()),
261                _ => Err(io::Error::last_os_error())
262            }
263        }
264    }
265
266    ///Receives some bytes from socket
267    ///
268    ///Number of received bytes is returned on success
269    pub fn recv(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
270        #[cfg(feature = "safe_buffer_len")]
271        let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
272        #[cfg(not(feature = "safe_buffer_len"))]
273        let len = buf.len() as i32;
274
275        unsafe {
276            match winapi::recv(self.inner, buf.as_mut_ptr() as *mut c_char, len, flags) {
277                -1 => {
278                    let error = io::Error::last_os_error();
279                    let raw_code = error.raw_os_error().unwrap();
280
281                    if raw_code == winapi::WSAESHUTDOWN as i32 {
282                        Ok(0)
283                    }
284                    else {
285                        Err(error)
286                    }
287                },
288                n => Ok(n as usize)
289            }
290        }
291    }
292
293    ///Receives some bytes from socket
294    ///
295    ///Number of received bytes and remote address are returned on success.
296    pub fn recv_from(&self, buf: &mut [u8], flags: c_int) -> io::Result<(usize, net::SocketAddr)> {
297        #[cfg(feature = "safe_buffer_len")]
298        let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
299        #[cfg(not(feature = "safe_buffer_len"))]
300        let len = buf.len() as i32;
301
302        unsafe {
303            let mut storage: winapi::SOCKADDR_STORAGE_LH = mem::zeroed();
304            let mut storage_len = mem::size_of_val(&storage) as c_int;
305
306            match winapi::recvfrom(self.inner, buf.as_mut_ptr() as *mut c_char, len, flags, &mut storage as *mut _ as *mut _, &mut storage_len) {
307                -1 => {
308                    let error = io::Error::last_os_error();
309                    let raw_code = error.raw_os_error().unwrap();
310
311                    if raw_code == winapi::WSAESHUTDOWN as i32 {
312                        let peer_addr = sockaddr_to_addr(&storage, storage_len)?;
313                        Ok((0, peer_addr))
314                    }
315                    else {
316                        Err(error)
317                    }
318                },
319                n => {
320                    let peer_addr = sockaddr_to_addr(&storage, storage_len)?;
321                    Ok((n as usize, peer_addr))
322                }
323            }
324        }
325    }
326
327    ///Sends some bytes through socket.
328    ///
329    ///Number of sent bytes is returned.
330    pub fn send(&self, buf: &[u8], flags: c_int) -> io::Result<usize> {
331        #[cfg(feature = "safe_buffer_len")]
332        let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
333        #[cfg(not(feature = "safe_buffer_len"))]
334        let len = buf.len() as i32;
335
336        unsafe {
337            match winapi::send(self.inner, buf.as_ptr() as *const c_char, len, flags) {
338                -1 => {
339                    let error = io::Error::last_os_error();
340                    let raw_code = error.raw_os_error().unwrap();
341
342                    if raw_code == winapi::WSAESHUTDOWN as i32 {
343                        Ok(0)
344                    }
345                    else {
346                        Err(error)
347                    }
348                },
349                n => Ok(n as usize)
350            }
351        }
352    }
353
354    ///Sends some bytes through socket toward specified peer.
355    ///
356    ///Number of sent bytes is returned.
357    ///
358    ///Note: the socket will be bound, if it isn't already.
359    ///Use method `name` to determine address.
360    pub fn send_to(&self, buf: &[u8], peer_addr: &net::SocketAddr, flags: c_int) -> io::Result<usize> {
361        #[cfg(feature = "safe_buffer_len")]
362        let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
363        #[cfg(not(feature = "safe_buffer_len"))]
364        let len = buf.len() as i32;
365        let (addr, addr_len) = get_raw_addr(peer_addr);
366
367        unsafe {
368            match winapi::sendto(self.inner, buf.as_ptr() as *const c_char, len, flags, addr, addr_len) {
369                -1 => {
370                    let error = io::Error::last_os_error();
371                    let raw_code = error.raw_os_error().unwrap();
372
373                    if raw_code == winapi::WSAESHUTDOWN as i32 {
374                        Ok(0)
375                    }
376                    else {
377                        Err(error)
378                    }
379                },
380                n => Ok(n as usize)
381            }
382        }
383    }
384
385    ///Accept a new incoming client connection and return its files descriptor and address.
386    ///
387    ///This is an emulation of the corresponding Unix system call, that will automatically call
388    ///`.set_blocking` and `.set_inheritable` with parameter values based on the value of `flags`
389    ///on the created client socket:
390    ///
391    /// * `AcceptFlags::NON_BLOCKING`    – Mark the newly created socket as non-blocking
392    /// * `AcceptFlags::NON_INHERITABLE` – Mark the newly created socket as not inheritable by client processes
393    pub fn accept4(&self, flags: AcceptFlags) -> io::Result<(Socket, net::SocketAddr)> {
394        self.accept().map(|(sock, addr)| {
395            // Emulate the two most common (and useful) `accept4` flags
396            sock.set_blocking(!flags.contains(NON_BLOCKING)).expect("Setting newly obtained client socket blocking mode");
397            sock.set_inheritable(!flags.contains(NON_INHERITABLE)).expect("Setting newly obtained client socket inheritance mode");
398
399            (sock, addr)
400        })
401    }
402
403    ///Accepts incoming connection.
404    pub fn accept(&self) -> io::Result<(Socket, net::SocketAddr)> {
405        unsafe {
406            let mut storage: winapi::SOCKADDR_STORAGE_LH = mem::zeroed();
407            let mut len = mem::size_of_val(&storage) as c_int;
408
409            match winapi::accept(self.inner, &mut storage as *mut _ as *mut _, &mut len) {
410                winapi::INVALID_SOCKET => Err(io::Error::last_os_error()),
411                sock @ _ => {
412                    let addr = sockaddr_to_addr(&storage, len)?;
413                    Ok((Socket { inner: sock }, addr))
414                }
415            }
416        }
417    }
418
419    ///Connects socket with remote address.
420    pub fn connect(&self, addr: &net::SocketAddr) -> io::Result<()> {
421        let (addr, len) = get_raw_addr(addr);
422
423        unsafe {
424            match winapi::connect(self.inner, addr, len) {
425                0 => Ok(()),
426                _ => Err(io::Error::last_os_error())
427            }
428        }
429    }
430
431    ///Retrieves socket option.
432    pub fn get_opt<T>(&self, level: c_int, name: c_int) -> io::Result<T> {
433        unsafe {
434            let mut value: T = mem::zeroed();
435            let value_ptr = &mut value as *mut T as *mut c_char;
436            let mut value_len = mem::size_of::<T>() as c_int;
437
438            match winapi::getsockopt(self.inner, level, name, value_ptr, &mut value_len) {
439                0 => Ok(value),
440                _ => Err(io::Error::last_os_error())
441            }
442        }
443    }
444
445    ///Sets socket option
446    ///
447    ///Value is generally integer or C struct.
448    pub fn set_opt<T>(&self, level: c_int, name: c_int, value: T) -> io::Result<()> {
449        unsafe {
450            let value = &value as *const T as *const c_char;
451
452            match winapi::setsockopt(self.inner, level, name, value, mem::size_of::<T>() as c_int) {
453                0 => Ok(()),
454                _ => Err(io::Error::last_os_error())
455            }
456        }
457    }
458
459    ///Sets I/O parameters of socket.
460    ///
461    ///It uses `ioctlsocket` under hood.
462    pub fn ioctl(&self, request: c_int, value: c_ulong) -> io::Result<()> {
463        unsafe {
464            let mut value = value;
465            let value = &mut value as *mut c_ulong;
466
467            match winapi::ioctlsocket(self.inner, request, value) {
468                0 => Ok(()),
469                _ => Err(io::Error::last_os_error())
470            }
471        }
472    }
473
474    ///Sets non-blocking mode.
475    pub fn set_blocking(&self, value: bool) -> io::Result<()> {
476        self.ioctl(winapi::FIONBIO as c_int, (!value) as c_ulong)
477    }
478
479    ///Sets whether this socket will be inherited by child processes or not.
480    ///
481    ///Internally this implemented by calling `SetHandleInformation(sock, HANDLE_FLAG_INHERIT, …)`.
482    pub fn set_inheritable(&self, value: bool) -> io::Result<()> {
483        unsafe {
484            let flag = if value { winapi::HANDLE_FLAG_INHERIT } else { 0 };
485            match winapi::SetHandleInformation(self.inner as winapi::HANDLE, winapi::HANDLE_FLAG_INHERIT, flag) {
486                0 => Err(io::Error::last_os_error()),
487                _ => Ok(())
488            }
489        }
490    }
491
492	///Returns whether this socket will be inherited by child processes or not.
493	pub fn get_inheritable(&self) -> io::Result<bool> {
494		unsafe {
495			let mut flags: winapi::DWORD = 0;
496			match winapi::GetHandleInformation(self.inner as winapi::HANDLE, &mut flags as *mut _) {
497                0 => Err(io::Error::last_os_error()),
498                _ => Ok((flags & winapi::HANDLE_FLAG_INHERIT) != 0)
499            }
500        }
501	}
502
503    ///Stops receive and/or send over socket.
504    pub fn shutdown(&self, direction: ShutdownType) -> io::Result<()> {
505        unsafe {
506            match winapi::shutdown(self.inner, direction.into()) {
507                0 => Ok(()),
508                _ => Err(io::Error::last_os_error())
509            }
510        }
511    }
512
513    ///Closes socket.
514    ///
515    ///Note: on `Drop` socket will be closed on its own.
516    ///There is no need to close it explicitly.
517    pub fn close(&self) -> io::Result<()> {
518        unsafe {
519            match winapi::closesocket(self.inner) {
520                0 => Ok(()),
521                _ => Err(io::Error::last_os_error())
522            }
523        }
524    }
525}
526
527fn get_raw_addr(addr: &net::SocketAddr) -> (*const winapi::SOCKADDR, c_int) {
528    match *addr {
529        net::SocketAddr::V4(ref a) => {
530            (a as *const _ as *const _, mem::size_of_val(a) as c_int)
531        }
532        net::SocketAddr::V6(ref a) => {
533            (a as *const _ as *const _, mem::size_of_val(a) as c_int)
534        }
535    }
536}
537
538fn sockaddr_to_addr(storage: &winapi::SOCKADDR_STORAGE_LH, len: c_int) -> io::Result<net::SocketAddr> {
539    match storage.ss_family as c_int {
540        winapi::AF_INET => {
541            assert!(len as usize >= mem::size_of::<winapi::SOCKADDR_IN>());
542            let storage = unsafe { *(storage as *const _ as *const winapi::SOCKADDR_IN) };
543            let address = unsafe { storage.sin_addr.S_un.S_un_b() };
544            let ip = net::Ipv4Addr::new(address.s_b1,
545                                        address.s_b2,
546                                        address.s_b3,
547                                        address.s_b4);
548
549            //Note to_be() swap bytes on LE targets
550            //As IP stuff is always BE, we need swap only on LE targets
551            Ok(net::SocketAddr::V4(net::SocketAddrV4::new(ip, storage.sin_port.to_be())))
552        }
553        winapi::AF_INET6 => {
554            assert!(len as usize >= mem::size_of::<winapi::SOCKADDR_IN6_LH>());
555            let storage = unsafe { *(storage as *const _ as *const winapi::SOCKADDR_IN6_LH) };
556            let ip = unsafe { storage.sin6_addr.u.Byte().clone() };
557            let ip = net::Ipv6Addr::from(ip);
558
559            let scope = unsafe { *storage.u.sin6_scope_id() };
560
561            Ok(net::SocketAddr::V6(net::SocketAddrV6::new(ip, storage.sin6_port.to_be(), storage.sin6_flowinfo, scope)))
562        }
563        _ => {
564            Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid addr type."))
565        }
566    }
567}
568
569impl Drop for Socket {
570    fn drop(&mut self) {
571        let _ = self.shutdown(ShutdownType::Both);
572        let _ = self.close();
573    }
574}
575
576use std::os::windows::io::{
577    AsRawSocket,
578    FromRawSocket,
579    IntoRawSocket,
580};
581
582impl AsRawSocket for Socket {
583    fn as_raw_socket(&self) -> winapi::STD_SOCKET {
584        self.inner as winapi::STD_SOCKET
585    }
586}
587
588impl FromRawSocket for Socket {
589    unsafe fn from_raw_socket(sock: winapi::STD_SOCKET) -> Self {
590        Socket {inner: sock as winapi::SOCKET}
591    }
592}
593
594impl IntoRawSocket for Socket {
595    fn into_raw_socket(self) -> winapi::STD_SOCKET {
596        let result = self.inner;
597        mem::forget(self);
598        result as winapi::STD_SOCKET
599    }
600}
601
602#[inline]
603fn ms_to_timeval(timeout_ms: u64) -> winapi::timeval {
604    winapi::timeval {
605        tv_sec: timeout_ms as c_long / 1000,
606        tv_usec: (timeout_ms as c_long % 1000) * 1000
607    }
608}
609
610fn sockets_to_fd_set(sockets: &[&Socket]) -> winapi::fd_set {
611    assert!(sockets.len() < winapi::FD_SETSIZE);
612    let mut raw_fds: winapi::fd_set = unsafe { mem::zeroed() };
613
614    for socket in sockets {
615        let idx = raw_fds.fd_count as usize;
616        raw_fds.fd_array[idx] = socket.inner;
617        raw_fds.fd_count += 1;
618    }
619
620    raw_fds
621}
622
623///Wrapper over system `select`
624///
625///Returns number of sockets that are ready.
626///
627///If timeout isn't specified then select will be blocking call.
628///
629///## Note:
630///
631///Number of each set cannot be bigger than FD_SETSIZE i.e. 64
632///
633///## Warning:
634///
635///It is invalid to pass all sets of descriptors empty on Windows.
636pub fn select(read_fds: &[&Socket], write_fds: &[&Socket], except_fds: &[&Socket], timeout_ms: Option<u64>) -> io::Result<c_int> {
637    let mut raw_read_fds = sockets_to_fd_set(read_fds);
638    let mut raw_write_fds = sockets_to_fd_set(write_fds);
639    let mut raw_except_fds = sockets_to_fd_set(except_fds);
640
641    unsafe {
642        match winapi::select(0,
643                             if read_fds.len() > 0 { &mut raw_read_fds } else { ptr::null_mut() },
644                             if write_fds.len() > 0 { &mut raw_write_fds } else { ptr::null_mut() },
645                             if except_fds.len() > 0 { &mut raw_except_fds } else { ptr::null_mut() },
646                             if let Some(timeout_ms) = timeout_ms { &ms_to_timeval(timeout_ms) } else { ptr::null() } ) {
647            winapi::SOCKET_ERROR => Err(io::Error::last_os_error()),
648            result @ _ => Ok(result)
649
650        }
651    }
652}