1use crate::socket::to_socket_protocol;
2use crate::socket::{IpVersion, SocketOption};
3use socket2::{SockAddr, Socket as SystemSocket};
4use std::io;
5use std::mem::MaybeUninit;
6use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
7use std::sync::Arc;
8use std::time::Duration;
9
10#[derive(Clone, Debug)]
12pub struct Socket {
13 inner: Arc<SystemSocket>,
14}
15
16impl Socket {
17 pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
19 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
20 SystemSocket::new(
21 socket_option.ip_version.to_domain(),
22 socket_option.socket_type.to_type(),
23 Some(to_socket_protocol(protocol)),
24 )?
25 } else {
26 SystemSocket::new(
27 socket_option.ip_version.to_domain(),
28 socket_option.socket_type.to_type(),
29 None,
30 )?
31 };
32 if socket_option.non_blocking {
33 socket.set_nonblocking(true)?;
34 }
35 Ok(Socket {
36 inner: Arc::new(socket),
37 })
38 }
39 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
41 let addr: SockAddr = SockAddr::from(addr);
42 self.inner.bind(&addr)
43 }
44 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
46 match self.inner.send(buf) {
47 Ok(n) => Ok(n),
48 Err(e) => Err(e),
49 }
50 }
51 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
53 let target: SockAddr = SockAddr::from(target);
54 match self.inner.send_to(buf, &target) {
55 Ok(n) => Ok(n),
56 Err(e) => Err(e),
57 }
58 }
59 pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
61 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
62 match self.inner.recv(recv_buf) {
63 Ok(result) => Ok(result),
64 Err(e) => Err(e),
65 }
66 }
67 pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
69 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
70 match self.inner.recv_from(recv_buf) {
71 Ok(result) => {
72 let (n, addr) = result;
73 match addr.as_socket() {
74 Some(addr) => return Ok((n, addr)),
75 None => {
76 return Err(io::Error::new(
77 io::ErrorKind::Other,
78 "Invalid socket address",
79 ))
80 }
81 }
82 }
83 Err(e) => Err(e),
84 }
85 }
86 pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
89 match self.inner.send(buf) {
90 Ok(n) => Ok(n),
91 Err(e) => Err(e),
92 }
93 }
94 pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
96 let mut offset = 0;
97 while offset < buf.len() {
98 match self.inner.send(&buf[offset..]) {
99 Ok(n) => offset += n,
100 Err(e) => return Err(e),
101 }
102 }
103 Ok(())
104 }
105 pub fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
108 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
109 match self.inner.recv(recv_buf) {
110 Ok(result) => Ok(result),
111 Err(e) => Err(e),
112 }
113 }
114 pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
116 let mut total = 0;
117 loop {
118 let mut recv_buf = Vec::new();
119 match self.receive(&mut recv_buf) {
120 Ok(n) => {
121 if n == 0 {
122 break;
123 }
124 total += n;
125 buf.extend_from_slice(&recv_buf[..n]);
126 }
127 Err(e) => return Err(e),
128 }
129 }
130 Ok(total)
131 }
132 pub fn read_to_end_timeout(&self, buf: &mut Vec<u8>, timeout: Duration) -> io::Result<usize> {
136 self.inner.set_read_timeout(Some(timeout))?;
138 let mut total = 0;
139 loop {
140 let mut recv_buf = Vec::new();
141 match self.receive(&mut recv_buf) {
142 Ok(n) => {
143 if n == 0 {
144 return Ok(total);
145 }
146 total += n;
147 buf.extend_from_slice(&recv_buf[..n]);
148 }
149 Err(e) => {
150 if e.kind() == io::ErrorKind::WouldBlock {
151 return Ok(total);
152 }
153 return Err(e);
154 }
155 }
156 }
157 }
158 pub fn ttl(&self, ip_version: IpVersion) -> io::Result<u32> {
160 match ip_version {
161 IpVersion::V4 => self.inner.ttl(),
162 IpVersion::V6 => self.inner.unicast_hops_v6(),
163 }
164 }
165 pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
167 match ip_version {
168 IpVersion::V4 => self.inner.set_ttl(ttl),
169 IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
170 }
171 }
172 pub fn tos(&self) -> io::Result<u32> {
174 self.inner.tos()
175 }
176 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
178 self.inner.set_tos(tos)
179 }
180 pub fn receive_tos(&self) -> io::Result<bool> {
182 self.inner.recv_tos()
183 }
184 pub fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> {
186 self.inner.set_recv_tos(receive_tos)
187 }
188 pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
190 let addr: SockAddr = SockAddr::from(*addr);
191 self.inner.connect(&addr)
192 }
193 pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
196 let addr: SockAddr = SockAddr::from(*addr);
197 self.inner.connect_timeout(&addr, timeout)
198 }
199 pub fn listen(&self, backlog: i32) -> io::Result<()> {
201 self.inner.listen(backlog)
202 }
203 pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
205 match self.inner.accept() {
206 Ok((socket, addr)) => Ok((
207 Socket {
208 inner: Arc::new(socket),
209 },
210 addr.as_socket().unwrap(),
211 )),
212 Err(e) => Err(e),
213 }
214 }
215 pub fn local_addr(&self) -> io::Result<SocketAddr> {
217 match self.inner.local_addr() {
218 Ok(addr) => Ok(addr.as_socket().unwrap()),
219 Err(e) => Err(e),
220 }
221 }
222 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
224 match self.inner.peer_addr() {
225 Ok(addr) => Ok(addr.as_socket().unwrap()),
226 Err(e) => Err(e),
227 }
228 }
229 pub fn socket_type(&self) -> io::Result<crate::socket::SocketType> {
231 match self.inner.r#type() {
232 Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)),
233 Err(e) => Err(e),
234 }
235 }
236 pub fn try_clone(&self) -> io::Result<Socket> {
238 match self.inner.try_clone() {
239 Ok(socket) => Ok(Socket {
240 inner: Arc::new(socket),
241 }),
242 Err(e) => Err(e),
243 }
244 }
245 #[cfg(not(target_os = "windows"))]
247 pub fn is_nonblocking(&self) -> io::Result<bool> {
248 self.inner.nonblocking()
249 }
250 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
252 self.inner.set_nonblocking(nonblocking)
253 }
254 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
256 self.inner.shutdown(how)
257 }
258 pub fn is_broadcast(&self) -> io::Result<bool> {
260 self.inner.broadcast()
261 }
262 pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
266 self.inner.set_broadcast(broadcast)
267 }
268 pub fn get_error(&self) -> io::Result<Option<io::Error>> {
270 self.inner.take_error()
271 }
272 pub fn keepalive(&self) -> io::Result<bool> {
274 self.inner.keepalive()
275 }
276 pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
280 self.inner.set_keepalive(keepalive)
281 }
282 pub fn linger(&self) -> io::Result<Option<Duration>> {
284 self.inner.linger()
285 }
286 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
288 self.inner.set_linger(dur)
289 }
290 pub fn receive_buffer_size(&self) -> io::Result<usize> {
292 self.inner.recv_buffer_size()
293 }
294 pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
298 self.inner.set_recv_buffer_size(size)
299 }
300 pub fn receive_timeout(&self) -> io::Result<Option<Duration>> {
302 self.inner.read_timeout()
303 }
304 pub fn set_receive_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
306 self.inner.set_read_timeout(duration)
307 }
308 pub fn reuse_address(&self) -> io::Result<bool> {
310 self.inner.reuse_address()
311 }
312 pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
316 self.inner.set_reuse_address(reuse)
317 }
318 pub fn send_buffer_size(&self) -> io::Result<usize> {
320 self.inner.send_buffer_size()
321 }
322 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
326 self.inner.set_send_buffer_size(size)
327 }
328 pub fn send_timeout(&self) -> io::Result<Option<Duration>> {
330 self.inner.write_timeout()
331 }
332 pub fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
336 self.inner.set_write_timeout(duration)
337 }
338 pub fn is_ip_header_included(&self) -> io::Result<bool> {
340 self.inner.header_included()
341 }
342 pub fn set_ip_header_included(&self, include: bool) -> io::Result<()> {
344 self.inner.set_header_included(include)
345 }
346 pub fn nodelay(&self) -> io::Result<bool> {
348 self.inner.nodelay()
349 }
350 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
354 self.inner.set_nodelay(nodelay)
355 }
356 pub fn into_tcp_stream(self) -> io::Result<TcpStream> {
359 match Arc::try_unwrap(self.inner) {
360 Ok(socket) => Ok(socket.into()),
361 Err(_) => Err(io::Error::new(
362 io::ErrorKind::Other,
363 "Failed to unwrap socket",
364 )),
365 }
366 }
367 pub fn into_tcp_listener(self) -> io::Result<TcpListener> {
370 match Arc::try_unwrap(self.inner) {
371 Ok(socket) => Ok(socket.into()),
372 Err(_) => Err(io::Error::new(
373 io::ErrorKind::Other,
374 "Failed to unwrap socket",
375 )),
376 }
377 }
378 pub fn into_udp_socket(self) -> io::Result<UdpSocket> {
381 match Arc::try_unwrap(self.inner) {
382 Ok(socket) => Ok(socket.into()),
383 Err(_) => Err(io::Error::new(
384 io::ErrorKind::Other,
385 "Failed to unwrap socket",
386 )),
387 }
388 }
389}