1use std::io::{Error, ErrorKind, IoSlice, IoSliceMut, Result};
4use std::mem::{size_of, transmute, zeroed};
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::os::unix::io::{AsRawFd, RawFd};
7use libc::{AF_INET, AF_INET6, c_int, msghdr, sockaddr_storage, socklen_t};
8use socket2::{Socket, SockAddr};
9use crate::{Domain, Type, Protocol};
10use crate::option::{Level, Name, Opt};
11
12pub struct RawSocket {
13 sys: Socket,
14}
15
16impl RawSocket {
17 pub fn new(domain: Domain, kind: Type, protocol: Option<Protocol>) -> Result<Self> {
18 let sys = Socket::new(domain, kind, protocol)?;
19 Ok(Self { sys })
20 }
21
22 pub fn bind<A: ToSocketAddrs>(&self, addr: A) -> Result<()> {
23 self.sys.bind(&sockaddr(addr)?)
24 }
25
26 pub fn local_addr(&self) -> Result<SocketAddr> {
27 socketaddr(&self.sys.local_addr()?)
28 }
29
30 pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
31 let (n, addr) = self.sys.recv_from(buf)?;
32 Ok((n, socketaddr(&addr)?))
33 }
34
35 pub fn recv_msg(
36 &self,
37 data: &[IoSliceMut<'_>],
38 ctrl: &mut [u8]
39 ) -> Result<(usize, SocketAddr)> {
40 let fd = self.as_raw_fd();
41 unsafe {
42 let mut addr: sockaddr_storage = zeroed();
43 let addr = &mut addr as *mut _;
44 let addrlen = size_of::<sockaddr_storage>();
45
46 let mut msg: msghdr = zeroed();
47 msg.msg_name = addr as *mut _;
48 msg.msg_namelen = addrlen as _;
49 msg.msg_iov = data.as_ptr() as *mut _;
50 msg.msg_iovlen = data.len() as _;
51
52 if !ctrl.is_empty() {
53 msg.msg_control = ctrl.as_ptr() as *mut _;
54 msg.msg_controllen = ctrl.len() as _;
55 }
56
57 let n = match libc::recvmsg(fd, &mut msg, 0) {
58 n if n >= 0 => n as usize,
59 _ => Err(Error::last_os_error())?,
60 };
61
62 let addr = msg.msg_name as *const _;
63 let len = msg.msg_namelen;
64 let addr = socketaddr(&SockAddr::from_raw_parts(addr, len))?;
65
66 Ok((n, addr))
67 }
68 }
69
70 pub fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], addr: A) -> Result<usize> {
71 self.sys.send_to(buf, &sockaddr(addr)?)
72 }
73
74 pub fn send_msg<A: ToSocketAddrs>(
75 &self,
76 addr: A,
77 data: &[IoSlice<'_>],
78 ctrl: &[u8],
79 ) -> Result<usize> {
80 let fd = self.as_raw_fd();
81 let addr = sockaddr(addr)?;
82
83 unsafe {
84 let mut msg: msghdr = zeroed();
85 msg.msg_name = addr.as_ptr() as _;
86 msg.msg_namelen = addr.len() as _;
87 msg.msg_iov = data.as_ptr() as *mut _;
88 msg.msg_iovlen = data.len() as _;
89
90 if !ctrl.is_empty() {
91 msg.msg_control = ctrl.as_ptr() as *mut _;
92 msg.msg_controllen = ctrl.len() as _;
93 }
94
95 match libc::sendmsg(fd, &msg, 0) {
96 n if n >= 0 => Ok(n as usize),
97 _ => Err(Error::last_os_error()),
98 }
99 }
100 }
101
102 pub fn get_sockopt<O: Opt>(&self, level: Level, name: Name) -> Result<O> {
103 let fd = self.as_raw_fd();
104
105 let mut val = O::default();
106 let mut len = size_of::<O>() as socklen_t;
107
108 let ptr = &mut val as *mut _ as *mut _;
109 let len = &mut len as *mut _;
110
111 unsafe {
112 let level = transmute(level);
113 let name = transmute(name);
114 match libc::getsockopt(fd, level, name, ptr, len) {
115 0 => Ok(val),
116 _ => Err(Error::last_os_error()),
117 }
118 }
119 }
120
121 pub fn set_sockopt<O: Opt>(&self, level: Level, name: Name, value: &O) -> Result<()> {
122 let fd = self.as_raw_fd();
123 let ptr = value as *const _ as *const _;
124 let len = size_of::<O>() as socklen_t;
125
126 unsafe {
127 let level = transmute(level);
128 let name = transmute(name);
129 match libc::setsockopt(fd, level, name, ptr, len) {
130 0 => Ok(()),
131 _ => Err(Error::last_os_error()),
132 }
133 }
134 }
135
136 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
137 self.sys.set_nonblocking(nonblocking)
138 }
139}
140
141impl AsRawFd for RawSocket {
142 fn as_raw_fd(&self) -> RawFd {
143 self.sys.as_raw_fd()
144 }
145}
146
147fn sockaddr<A: ToSocketAddrs>(addr: A) -> Result<SockAddr> {
148 match addr.to_socket_addrs()?.next() {
149 Some(addr) => Ok(SockAddr::from(addr)),
150 None => Err(Error::new(ErrorKind::InvalidInput, "invalid socket address")),
151 }
152}
153
154fn socketaddr(addr: &SockAddr) -> Result<SocketAddr> {
155 match addr.family() as c_int {
156 AF_INET => Ok(addr.as_inet().expect("AF_INET addr").into()),
157 AF_INET6 => Ok(addr.as_inet6().expect("AF_INET6 addr").into()),
158 _ => Err(Error::new(ErrorKind::Other, "unknown address type")),
159 }
160}