raw_socket/
socket.rs

1// Copyright (C) 2020 - Will Glozer. All rights reserved.
2
3use std::io::{Error, ErrorKind, IoSlice, IoSliceMut, Result};
4use std::mem::{size_of, transmute, zeroed};
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::os::unix::io::{AsRawFd, RawFd};
7use libc::{AF_INET, AF_INET6, c_int, msghdr, sockaddr_storage, socklen_t};
8use socket2::{Socket, SockAddr};
9use crate::{Domain, Type, Protocol};
10use crate::option::{Level, Name, Opt};
11
12pub struct RawSocket {
13    sys: Socket,
14}
15
16impl RawSocket {
17    pub fn new(domain: Domain, kind: Type, protocol: Option<Protocol>) -> Result<Self> {
18        let sys = Socket::new(domain, kind, protocol)?;
19        Ok(Self { sys })
20    }
21
22    pub fn bind<A: ToSocketAddrs>(&self, addr: A) -> Result<()> {
23        self.sys.bind(&sockaddr(addr)?)
24    }
25
26    pub fn local_addr(&self) -> Result<SocketAddr> {
27        socketaddr(&self.sys.local_addr()?)
28    }
29
30    pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
31        let (n, addr) = self.sys.recv_from(buf)?;
32        Ok((n, socketaddr(&addr)?))
33    }
34
35    pub fn recv_msg(
36        &self,
37        data: &[IoSliceMut<'_>],
38        ctrl: &mut [u8]
39    ) -> Result<(usize, SocketAddr)> {
40        let fd = self.as_raw_fd();
41        unsafe {
42            let mut addr: sockaddr_storage = zeroed();
43            let addr    = &mut addr as *mut _;
44            let addrlen = size_of::<sockaddr_storage>();
45
46            let mut msg: msghdr = zeroed();
47            msg.msg_name    = addr          as *mut _;
48            msg.msg_namelen = addrlen       as      _;
49            msg.msg_iov     = data.as_ptr() as *mut _;
50            msg.msg_iovlen  = data.len()    as      _;
51
52            if !ctrl.is_empty() {
53                msg.msg_control    = ctrl.as_ptr() as *mut _;
54                msg.msg_controllen = ctrl.len()    as      _;
55            }
56
57            let n = match libc::recvmsg(fd, &mut msg, 0) {
58                n if n >= 0 => n as usize,
59                _           => Err(Error::last_os_error())?,
60            };
61
62            let addr = msg.msg_name as *const _;
63            let len  = msg.msg_namelen;
64            let addr = socketaddr(&SockAddr::from_raw_parts(addr, len))?;
65
66            Ok((n, addr))
67        }
68    }
69
70    pub fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], addr: A) -> Result<usize> {
71        self.sys.send_to(buf, &sockaddr(addr)?)
72    }
73
74    pub fn send_msg<A: ToSocketAddrs>(
75        &self,
76        addr: A,
77        data: &[IoSlice<'_>],
78        ctrl: &[u8],
79    ) -> Result<usize> {
80        let fd   = self.as_raw_fd();
81        let addr = sockaddr(addr)?;
82
83        unsafe {
84            let mut msg: msghdr = zeroed();
85            msg.msg_name    = addr.as_ptr() as      _;
86            msg.msg_namelen = addr.len()    as      _;
87            msg.msg_iov     = data.as_ptr() as *mut _;
88            msg.msg_iovlen  = data.len()    as      _;
89
90            if !ctrl.is_empty() {
91                msg.msg_control    = ctrl.as_ptr() as *mut _;
92                msg.msg_controllen = ctrl.len()    as      _;
93            }
94
95            match libc::sendmsg(fd, &msg, 0) {
96                n if n >= 0 => Ok(n as usize),
97                _           => Err(Error::last_os_error()),
98            }
99        }
100    }
101
102    pub fn get_sockopt<O: Opt>(&self, level: Level, name: Name) -> Result<O> {
103        let fd = self.as_raw_fd();
104
105        let mut val = O::default();
106        let mut len = size_of::<O>() as socklen_t;
107
108        let ptr = &mut val as *mut _ as *mut _;
109        let len = &mut len as *mut _;
110
111        unsafe {
112            let level = transmute(level);
113            let name  = transmute(name);
114            match libc::getsockopt(fd, level, name, ptr, len) {
115                0 => Ok(val),
116                _ => Err(Error::last_os_error()),
117            }
118        }
119    }
120
121    pub fn set_sockopt<O: Opt>(&self, level: Level, name: Name, value: &O) -> Result<()> {
122        let fd  = self.as_raw_fd();
123        let ptr = value as *const _ as *const _;
124        let len = size_of::<O>() as socklen_t;
125
126        unsafe {
127            let level = transmute(level);
128            let name  = transmute(name);
129            match libc::setsockopt(fd, level, name, ptr, len) {
130                0 => Ok(()),
131                _ => Err(Error::last_os_error()),
132            }
133        }
134    }
135
136    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
137        self.sys.set_nonblocking(nonblocking)
138    }
139}
140
141impl AsRawFd for RawSocket {
142    fn as_raw_fd(&self) -> RawFd {
143        self.sys.as_raw_fd()
144    }
145}
146
147fn sockaddr<A: ToSocketAddrs>(addr: A) -> Result<SockAddr> {
148    match addr.to_socket_addrs()?.next() {
149        Some(addr) => Ok(SockAddr::from(addr)),
150        None       => Err(Error::new(ErrorKind::InvalidInput, "invalid socket address")),
151    }
152}
153
154fn socketaddr(addr: &SockAddr) -> Result<SocketAddr> {
155    match addr.family() as c_int {
156        AF_INET  => Ok(addr.as_inet().expect("AF_INET addr").into()),
157        AF_INET6 => Ok(addr.as_inet6().expect("AF_INET6 addr").into()),
158        _        => Err(Error::new(ErrorKind::Other, "unknown address type")),
159    }
160}