// libicmp/socket.rs

1use std::io;
2use std::cmp;
3use std::mem;
4use std::net::SocketAddr;
5
6use libc;
7
8type FileDesc = libc::c_int;
9
10pub const IPPROTO_ICMP: libc::c_int = 1;
11
/// Turns a libc-style return value into an `io::Result`.
///
/// A value of `-1` is treated as failure and converted into the calling
/// thread's last OS error (`errno`); any other value is passed through in `Ok`.
fn cvt<T: IsMinusOne>(t: T) -> io::Result<T> {
    if !t.is_minus_one() {
        return Ok(t);
    }
    Err(io::Error::last_os_error())
}

/// Integer types whose `-1` value signals failure in a libc return.
trait IsMinusOne {
    /// Returns `true` exactly when the value equals `-1`.
    fn is_minus_one(&self) -> bool;
}

// Generates an `IsMinusOne` impl for every signed integer type listed.
macro_rules! impl_is_minus_one {
    ($($int:ident)*) => {
        $(
            impl IsMinusOne for $int {
                fn is_minus_one(&self) -> bool {
                    -1 == *self
                }
            }
        )*
    };
}

impl_is_minus_one! { i8 i16 i32 i64 isize }

34/// Sends and receives messages over a SOCK_RAW.
35pub struct RawSocket(FileDesc);
36
37impl RawSocket {
38
39    /// Returns a new RawSocket for sending and receiving messages on
40    /// IPPROTO_ICMP over a SOCK_RAW. Currently only supports IPv4.
41    /// TODO: use libc::IPPROTO* when the const becomes available,
42    ///       and add some config for IPv4/IPv6.
43    pub fn new() -> io::Result<RawSocket> {
44        unsafe {
45            // try to open with SOCK_CLOEXEC, otherwise use a fallback
46            if cfg!(target_os = "linux") {
47                match cvt(libc::socket(libc::AF_INET, libc::SOCK_RAW | libc::SOCK_CLOEXEC, IPPROTO_ICMP)) {
48                    Ok(fd) => return Ok(RawSocket(fd)),
49                    Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => {}
50                    Err(e) => return Err(e),
51                }
52            }
53
54            let fd = cvt(libc::socket(libc::AF_INET, libc::SOCK_RAW, IPPROTO_ICMP))?;
55            cvt(libc::ioctl(fd, libc::FIOCLEX))?;
56            let socket = RawSocket(fd);
57            Ok(socket)
58        }
59    }
60
61    /// Returns the socket's underlying file descriptor.
62    pub fn fd(&self) -> libc::c_int {
63        self.0
64    }
65
66    /// Sets the socket to non-blocking mode so that reads return immediately.
67    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
68        let mut nonblocking = nonblocking as libc::c_int;
69        cvt(unsafe { libc::ioctl(self.0, libc::FIONBIO, &mut nonblocking) }).map(|_| ())
70    }
71
72    /// Sends a set of bytes over the socket and returns the number of bytes written.
73    pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
74        let len = cmp::min(buf.len(), <libc::size_t>::max_value() as usize) as libc::size_t;
75        let (dstp, dstlen) = into_inner(dst);
76
77        let ret = cvt(unsafe {
78            libc::sendto(self.0, buf.as_ptr() as *const libc::c_void, len, 0, dstp, dstlen)
79        })?;
80        Ok(ret as usize)
81    }
82
83    /// Reads the next available packet into the buffer and returns the number
84    /// of bytes read. The packet is completely consumed, even if it is only
85    /// partially read.
86    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<usize> {
87        let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
88        let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
89
90        let n = cvt(unsafe {
91            libc::recvfrom(self.0, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), 0, &mut storage as *mut _ as *mut _, &mut addrlen)
92        })?;
93        Ok(n as usize)
94    }
95}
96
97impl Drop for RawSocket {
98    fn drop(&mut self) {
99        unsafe {
100            libc::close(self.0);
101        }
102    }
103}
104
105fn into_inner(s: &SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) {
106    match *s {
107        SocketAddr::V4(ref a) => {
108            (a as *const _ as *const _, mem::size_of_val(a) as libc::socklen_t)
109        }
110        SocketAddr::V6(ref a) => {
111            (a as *const _ as *const _, mem::size_of_val(a) as libc::socklen_t)
112        }
113    }
114}
115
#[cfg(test)]
mod tests {
    /// Dropping a `RawSocket` must release its descriptor. After creating and
    /// dropping several sockets, `/proc/self/fd` should list as many entries as before.
    #[test]
    #[cfg(target_os = "linux")]
    fn it_closes_sockets() {
        use std::fs::read_dir;
        use std::io::ErrorKind;

        use super::RawSocket;

        // Opening SOCK_RAW needs privileges (root / CAP_NET_RAW); skip instead of
        // failing spuriously when the test suite runs unprivileged.
        match RawSocket::new() {
            Ok(probe) => drop(probe),
            Err(ref e) if e.kind() == ErrorKind::PermissionDenied => {
                eprintln!("skipping it_closes_sockets: {}", e);
                return;
            }
            Err(e) => panic!("failed to open raw socket: {}", e),
        }

        let initial_descriptors = read_dir("/proc/self/fd").unwrap().count();

        for _ in 0..5 {
            drop(RawSocket::new().unwrap());
        }

        let final_descriptors = read_dir("/proc/self/fd").unwrap().count();

        assert_eq!(initial_descriptors, final_descriptors);
    }
}