nex_socket/socket/sync_impl.rs

1use crate::socket::to_socket_protocol;
2use crate::socket::{IpVersion, SocketOption};
3use socket2::{SockAddr, Socket as SystemSocket};
4use std::io;
5use std::mem::MaybeUninit;
6use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
7use std::sync::Arc;
8use std::time::Duration;
9
10/// Socket. Provides cross-platform adapter for system socket.
11#[derive(Clone, Debug)]
12pub struct Socket {
13    inner: Arc<SystemSocket>,
14}
15
16impl Socket {
17    /// Constructs a new Socket.
18    pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
19        let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
20            SystemSocket::new(
21                socket_option.ip_version.to_domain(),
22                socket_option.socket_type.to_type(),
23                Some(to_socket_protocol(protocol)),
24            )?
25        } else {
26            SystemSocket::new(
27                socket_option.ip_version.to_domain(),
28                socket_option.socket_type.to_type(),
29                None,
30            )?
31        };
32        if socket_option.non_blocking {
33            socket.set_nonblocking(true)?;
34        }
35        Ok(Socket {
36            inner: Arc::new(socket),
37        })
38    }
39    /// Bind socket to address.
40    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
41        let addr: SockAddr = SockAddr::from(addr);
42        self.inner.bind(&addr)
43    }
44    /// Send packet.
45    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
46        match self.inner.send(buf) {
47            Ok(n) => Ok(n),
48            Err(e) => Err(e),
49        }
50    }
51    /// Send packet to target.
52    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
53        let target: SockAddr = SockAddr::from(target);
54        match self.inner.send_to(buf, &target) {
55            Ok(n) => Ok(n),
56            Err(e) => Err(e),
57        }
58    }
59    /// Receive packet.
60    pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
61        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
62        match self.inner.recv(recv_buf) {
63            Ok(result) => Ok(result),
64            Err(e) => Err(e),
65        }
66    }
67    /// Receive packet with sender address.
68    pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
69        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
70        match self.inner.recv_from(recv_buf) {
71            Ok(result) => {
72                let (n, addr) = result;
73                match addr.as_socket() {
74                    Some(addr) => return Ok((n, addr)),
75                    None => {
76                        return Err(io::Error::new(
77                            io::ErrorKind::Other,
78                            "Invalid socket address",
79                        ))
80                    }
81                }
82            }
83            Err(e) => Err(e),
84        }
85    }
86    /// Write data to the socket and send to the target.
87    /// Return how many bytes were written.
88    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
89        match self.inner.send(buf) {
90            Ok(n) => Ok(n),
91            Err(e) => Err(e),
92        }
93    }
94    /// Attempts to write an entire buffer into this writer.
95    pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
96        let mut offset = 0;
97        while offset < buf.len() {
98            match self.inner.send(&buf[offset..]) {
99                Ok(n) => offset += n,
100                Err(e) => return Err(e),
101            }
102        }
103        Ok(())
104    }
105    /// Read data from the socket.
106    /// Return how many bytes were read.
107    pub fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
108        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
109        match self.inner.recv(recv_buf) {
110            Ok(result) => Ok(result),
111            Err(e) => Err(e),
112        }
113    }
114    /// Read all bytes until EOF in this source, placing them into buf.
115    pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
116        let mut total = 0;
117        loop {
118            let mut recv_buf = Vec::new();
119            match self.receive(&mut recv_buf) {
120                Ok(n) => {
121                    if n == 0 {
122                        break;
123                    }
124                    total += n;
125                    buf.extend_from_slice(&recv_buf[..n]);
126                }
127                Err(e) => return Err(e),
128            }
129        }
130        Ok(total)
131    }
132    /// Read all bytes until EOF in this source, placing them into buf.
133    /// This ignore io::Error on read_to_end because it is expected when reading response.
134    /// If no response is received, and io::Error is occurred, return Err.
135    pub fn read_to_end_timeout(&self, buf: &mut Vec<u8>, timeout: Duration) -> io::Result<usize> {
136        // Set timeout
137        self.inner.set_read_timeout(Some(timeout))?;
138        let mut total = 0;
139        loop {
140            let mut recv_buf = Vec::new();
141            match self.receive(&mut recv_buf) {
142                Ok(n) => {
143                    if n == 0 {
144                        return Ok(total);
145                    }
146                    total += n;
147                    buf.extend_from_slice(&recv_buf[..n]);
148                }
149                Err(e) => {
150                    if e.kind() == io::ErrorKind::WouldBlock {
151                        return Ok(total);
152                    }
153                    return Err(e);
154                }
155            }
156        }
157    }
158    /// Get TTL or Hop Limit.
159    pub fn ttl(&self, ip_version: IpVersion) -> io::Result<u32> {
160        match ip_version {
161            IpVersion::V4 => self.inner.ttl(),
162            IpVersion::V6 => self.inner.unicast_hops_v6(),
163        }
164    }
165    /// Set TTL or Hop Limit.
166    pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
167        match ip_version {
168            IpVersion::V4 => self.inner.set_ttl(ttl),
169            IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
170        }
171    }
172    /// Get the value of the IP_TOS option for this socket.
173    pub fn tos(&self) -> io::Result<u32> {
174        self.inner.tos()
175    }
176    /// Set the value of the IP_TOS option for this socket.
177    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
178        self.inner.set_tos(tos)
179    }
180    /// Get the value of the IP_RECVTOS option for this socket.
181    pub fn receive_tos(&self) -> io::Result<bool> {
182        self.inner.recv_tos()
183    }
184    /// Set the value of the IP_RECVTOS option for this socket.
185    pub fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> {
186        self.inner.set_recv_tos(receive_tos)
187    }
188    /// Initiate TCP connection.
189    pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
190        let addr: SockAddr = SockAddr::from(*addr);
191        self.inner.connect(&addr)
192    }
193    /// Initiate a connection on this socket to the specified address, only only waiting for a certain period of time for the connection to be established.
194    /// The non-blocking state of the socket is overridden by this function.
195    pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
196        let addr: SockAddr = SockAddr::from(*addr);
197        self.inner.connect_timeout(&addr, timeout)
198    }
199    /// Listen TCP connection.
200    pub fn listen(&self, backlog: i32) -> io::Result<()> {
201        self.inner.listen(backlog)
202    }
203    /// Accept TCP connection.
204    pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
205        match self.inner.accept() {
206            Ok((socket, addr)) => Ok((
207                Socket {
208                    inner: Arc::new(socket),
209                },
210                addr.as_socket().unwrap(),
211            )),
212            Err(e) => Err(e),
213        }
214    }
215    /// Get local address.
216    pub fn local_addr(&self) -> io::Result<SocketAddr> {
217        match self.inner.local_addr() {
218            Ok(addr) => Ok(addr.as_socket().unwrap()),
219            Err(e) => Err(e),
220        }
221    }
222    /// Get peer address.
223    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
224        match self.inner.peer_addr() {
225            Ok(addr) => Ok(addr.as_socket().unwrap()),
226            Err(e) => Err(e),
227        }
228    }
229    /// Get type of the socket.
230    pub fn socket_type(&self) -> io::Result<crate::socket::SocketType> {
231        match self.inner.r#type() {
232            Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)),
233            Err(e) => Err(e),
234        }
235    }
236    /// Create a new socket with the same configuration and bound to the same address.
237    pub fn try_clone(&self) -> io::Result<Socket> {
238        match self.inner.try_clone() {
239            Ok(socket) => Ok(Socket {
240                inner: Arc::new(socket),
241            }),
242            Err(e) => Err(e),
243        }
244    }
245    /// Returns true if this socket is set to nonblocking mode, false otherwise.
246    #[cfg(not(target_os = "windows"))]
247    pub fn is_nonblocking(&self) -> io::Result<bool> {
248        self.inner.nonblocking()
249    }
250    /// Set non-blocking mode.
251    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
252        self.inner.set_nonblocking(nonblocking)
253    }
254    /// Shutdown TCP connection.
255    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
256        self.inner.shutdown(how)
257    }
258    /// Get the value of the SO_BROADCAST option for this socket.
259    pub fn is_broadcast(&self) -> io::Result<bool> {
260        self.inner.broadcast()
261    }
262    /// Set the value of the `SO_BROADCAST` option for this socket.
263    ///
264    /// When enabled, this socket is allowed to send packets to a broadcast address.
265    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
266        self.inner.set_broadcast(broadcast)
267    }
268    /// Get the value of the `SO_ERROR` option on this socket.
269    pub fn get_error(&self) -> io::Result<Option<io::Error>> {
270        self.inner.take_error()
271    }
272    /// Get the value of the `SO_KEEPALIVE` option on this socket.
273    pub fn keepalive(&self) -> io::Result<bool> {
274        self.inner.keepalive()
275    }
276    /// Set value for the `SO_KEEPALIVE` option on this socket.
277    ///
278    /// Enable sending of keep-alive messages on connection-oriented sockets.
279    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
280        self.inner.set_keepalive(keepalive)
281    }
282    /// Get the value of the SO_LINGER option on this socket.
283    pub fn linger(&self) -> io::Result<Option<Duration>> {
284        self.inner.linger()
285    }
286    /// Set value for the SO_LINGER option on this socket.
287    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
288        self.inner.set_linger(dur)
289    }
290    /// Get the value of the `SO_RCVBUF` option on this socket.
291    pub fn receive_buffer_size(&self) -> io::Result<usize> {
292        self.inner.recv_buffer_size()
293    }
294    /// Set value for the `SO_RCVBUF` option on this socket.
295    ///
296    /// Changes the size of the operating system's receive buffer associated with the socket.
297    pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
298        self.inner.set_recv_buffer_size(size)
299    }
300    /// Get value for the SO_RCVTIMEO option on this socket.
301    pub fn receive_timeout(&self) -> io::Result<Option<Duration>> {
302        self.inner.read_timeout()
303    }
304    /// Set value for the `SO_RCVTIMEO` option on this socket.
305    pub fn set_receive_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
306        self.inner.set_read_timeout(duration)
307    }
308    /// Get value for the `SO_REUSEADDR` option on this socket.
309    pub fn reuse_address(&self) -> io::Result<bool> {
310        self.inner.reuse_address()
311    }
312    /// Set value for the `SO_REUSEADDR` option on this socket.
313    ///
314    /// This indicates that futher calls to `bind` may allow reuse of local addresses.
315    pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
316        self.inner.set_reuse_address(reuse)
317    }
318    /// Get value for the `SO_SNDBUF` option on this socket.
319    pub fn send_buffer_size(&self) -> io::Result<usize> {
320        self.inner.send_buffer_size()
321    }
322    /// Set value for the `SO_SNDBUF` option on this socket.
323    ///
324    /// Changes the size of the operating system's send buffer associated with the socket.
325    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
326        self.inner.set_send_buffer_size(size)
327    }
328    /// Get value for the `SO_SNDTIMEO` option on this socket.
329    pub fn send_timeout(&self) -> io::Result<Option<Duration>> {
330        self.inner.write_timeout()
331    }
332    /// Set value for the `SO_SNDTIMEO` option on this socket.
333    ///
334    /// If `timeout` is `None`, then `write` and `send` calls will block indefinitely.
335    pub fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
336        self.inner.set_write_timeout(duration)
337    }
338    /// Get the value of the IP_HDRINCL option on this socket.
339    pub fn is_ip_header_included(&self) -> io::Result<bool> {
340        self.inner.header_included()
341    }
342    /// Set the value of the `IP_HDRINCL` option on this socket.
343    pub fn set_ip_header_included(&self, include: bool) -> io::Result<()> {
344        self.inner.set_header_included(include)
345    }
346    /// Get the value of the TCP_NODELAY option on this socket.
347    pub fn nodelay(&self) -> io::Result<bool> {
348        self.inner.nodelay()
349    }
350    /// Set the value of the `TCP_NODELAY` option on this socket.
351    ///
352    /// If set, segments are always sent as soon as possible, even if there is only a small amount of data.
353    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
354        self.inner.set_nodelay(nodelay)
355    }
356    /// Get TCP Stream
357    /// This function will consume the socket and return a new std::net::TcpStream.
358    pub fn into_tcp_stream(self) -> io::Result<TcpStream> {
359        match Arc::try_unwrap(self.inner) {
360            Ok(socket) => Ok(socket.into()),
361            Err(_) => Err(io::Error::new(
362                io::ErrorKind::Other,
363                "Failed to unwrap socket",
364            )),
365        }
366    }
367    /// Get TCP Listener
368    /// This function will consume the socket and return a new std::net::TcpListener.
369    pub fn into_tcp_listener(self) -> io::Result<TcpListener> {
370        match Arc::try_unwrap(self.inner) {
371            Ok(socket) => Ok(socket.into()),
372            Err(_) => Err(io::Error::new(
373                io::ErrorKind::Other,
374                "Failed to unwrap socket",
375            )),
376        }
377    }
378    /// Get UDP Socket
379    /// This function will consume the socket and return a new std::net::UdpSocket.
380    pub fn into_udp_socket(self) -> io::Result<UdpSocket> {
381        match Arc::try_unwrap(self.inner) {
382            Ok(socket) => Ok(socket.into()),
383            Err(_) => Err(io::Error::new(
384                io::ErrorKind::Other,
385                "Failed to unwrap socket",
386            )),
387        }
388    }
389}