1use filedesc::FileDesc;
2use std::io::{IoSlice, IoSliceMut};
3use std::os::raw::{c_int, c_void};
4use std::os::unix::io::{RawFd, AsRawFd, IntoRawFd, FromRawFd};
5
6use crate::AsSocketAddress;
7use crate::ancillary::SocketAncillary;
8
9pub struct Socket<Address> {
11 fd: FileDesc,
12 _address: std::marker::PhantomData<fn() -> Address>,
13}
14
15#[cfg(not(any(target_os = "apple", target_os = "solaris")))]
16mod extra_flags {
17 pub const SENDMSG: std::os::raw::c_int = libc::MSG_NOSIGNAL;
18 pub const RECVMSG: std::os::raw::c_int = libc::MSG_CMSG_CLOEXEC;
19}
20
21#[cfg(any(target_os = "apple", target_os = "solaris"))]
22mod extra_flags {
23 pub const SENDMSG: std::os::raw::c_int = 0;
24 pub const RECVMSG: std::os::raw::c_int = 0;
25}
26
27impl<Address: AsSocketAddress> Socket<Address> {
28 fn wrap(fd: FileDesc) -> std::io::Result<Self> {
32 let wrapped = Self {
33 fd,
34 _address: std::marker::PhantomData,
35 };
36
37 #[cfg(target_os = "apple")]
38 wrapped.set_option(libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1 as c_int)?;
39
40 Ok(wrapped)
41 }
42
43 pub fn new(kind: c_int, protocol: c_int) -> std::io::Result<Self>
52 where
53 Address: crate::SpecificSocketAddress,
54 {
55 Self::new_generic(Address::static_family() as c_int, kind, protocol)
56 }
57
58 pub fn new_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<Self> {
68 socket(domain, kind | libc::SOCK_CLOEXEC, protocol)
69 .or_else(|e| {
70 if e.raw_os_error() == Some(libc::EINVAL) {
72 let fd = socket(domain, kind, protocol)?;
73 fd.set_close_on_exec(true)?;
74 Ok(fd)
75 } else {
76 Err(e)
77 }
78 })
79 .and_then(Self::wrap)
80 }
81
82 pub fn pair(kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)>
91 where
92 Address: crate::SpecificSocketAddress,
93 {
94 Self::pair_generic(Address::static_family() as c_int, kind, protocol)
95 }
96
97 pub fn pair_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)> {
107 socketpair(domain, kind | libc::SOCK_CLOEXEC, protocol)
108 .or_else(|e| {
109 if e.raw_os_error() == Some(libc::EINVAL) {
111 let (a, b) = socketpair(domain, kind, protocol)?;
112 a.set_close_on_exec(true)?;
113 b.set_close_on_exec(true)?;
114 Ok((a, b))
115 } else {
116 Err(e)
117 }
118 })
119 .and_then(|(a, b)| {
120 Ok((Self::wrap(a)?, Self::wrap(b)?))
121 })
122 }
123
124 pub fn try_clone(&self) -> std::io::Result<Self> {
132 Ok(Self {
133 fd: self.fd.duplicate()?,
134 _address: std::marker::PhantomData,
135 })
136 }
137
138 pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
144 Self {
145 fd: FileDesc::from_raw_fd(fd),
146 _address: std::marker::PhantomData,
147 }
148 }
149
150 pub fn as_raw_fd(&self) -> RawFd {
155 self.fd.as_raw_fd()
156 }
157
158 pub fn into_raw_fd(self) -> RawFd {
163 self.fd.into_raw_fd()
164 }
165
166 fn set_option<T: Copy>(&self, level: c_int, option: c_int, value: T) -> std::io::Result<()> {
170 unsafe {
171 let value = &value as *const T as *const c_void;
172 let length = std::mem::size_of::<T>() as libc::socklen_t;
173 check_ret(libc::setsockopt(self.as_raw_fd(), level, option, value, length))?;
174 Ok(())
175 }
176 }
177
178 fn get_option<T: Copy>(&self, level: c_int, option: c_int) -> std::io::Result<T> {
182 unsafe {
183 let mut output = std::mem::MaybeUninit::zeroed();
184 let output_ptr = output.as_mut_ptr() as *mut c_void;
185 let mut length = std::mem::size_of::<T>() as libc::socklen_t;
186 check_ret(libc::getsockopt(self.as_raw_fd(), level, option, output_ptr, &mut length))?;
187 assert_eq!(length, std::mem::size_of::<T>() as libc::socklen_t);
188 Ok(output.assume_init())
189 }
190 }
191
192 pub fn set_nonblocking(&self, non_blocking: bool) -> std::io::Result<()> {
194 self.set_option(libc::SOL_SOCKET, libc::O_NONBLOCK, bool_to_c_int(non_blocking))
195 }
196
197 pub fn get_nonblocking(&self) -> std::io::Result<bool> {
199 let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::O_NONBLOCK)?;
200 Ok(raw != 0)
201 }
202
203 pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
208 let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::SO_ERROR)?;
209 if raw == 0 {
210 Ok(None)
211 } else {
212 Ok(Some(std::io::Error::from_raw_os_error(raw)))
213 }
214 }
215
216 pub fn local_addr(&self) -> std::io::Result<Address> {
218 unsafe {
219 let mut address = std::mem::MaybeUninit::<Address>::zeroed();
220 let mut len = Address::max_len();
221 check_ret(libc::getsockname(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len))?;
222 Address::finalize(address, len)
223 }
224 }
225
226 pub fn peer_addr(&self) -> std::io::Result<Address> {
228 unsafe {
229 let mut address = std::mem::MaybeUninit::<Address>::zeroed();
230 let mut len = Address::max_len();
231 check_ret(libc::getpeername(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len))?;
232 Address::finalize(address, len)
233 }
234 }
235
236 pub fn connect(&self, address: &Address) -> std::io::Result<()> {
241 unsafe {
242 check_ret(libc::connect(self.as_raw_fd(), address.as_sockaddr(), address.len()))?;
243 Ok(())
244 }
245 }
246
247 pub fn bind(&self, address: &Address) -> std::io::Result<()> {
252 unsafe {
253 check_ret(libc::bind(self.as_raw_fd(), address.as_sockaddr(), address.len()))?;
254 Ok(())
255 }
256 }
257
258 pub fn listen(&self, backlog: c_int) -> std::io::Result<()> {
266 unsafe {
267 check_ret(libc::listen(self.as_raw_fd(), backlog))?;
268 Ok(())
269 }
270 }
271
272 pub fn accept(&self) -> std::io::Result<(Self, Address)> {
280 unsafe {
281 let mut address = std::mem::MaybeUninit::zeroed();
282 let mut len = Address::max_len();
283 let fd = check_ret(libc::accept4(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len, libc::SOCK_CLOEXEC))?;
284 let socket = Self::wrap(FileDesc::from_raw_fd(fd))?;
285 let address = Address::finalize(address, len)?;
286 Ok((socket, address))
287 }
288 }
289
290 pub fn send(&self, data: &[u8], flags: c_int) -> std::io::Result<usize> {
296 unsafe {
297 let data_ptr = data.as_ptr() as *const c_void;
298 let transferred = check_ret_isize(libc::send(self.as_raw_fd(), data_ptr, data.len(), flags | extra_flags::SENDMSG))?;
299 Ok(transferred as usize)
300 }
301 }
302
303 pub fn send_to(&self, data: &[u8], address: &Address, flags: c_int) -> std::io::Result<usize> {
311 unsafe {
312 let data_ptr = data.as_ptr() as *const c_void;
313 let transferred = check_ret_isize(libc::sendto(
314 self.as_raw_fd(),
315 data_ptr,
316 data.len(),
317 flags | extra_flags::SENDMSG,
318 address.as_sockaddr(), address.len()
319 ))?;
320 Ok(transferred as usize)
321 }
322 }
323
324 pub fn send_msg(&self, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
330 unsafe {
331 let mut header = std::mem::zeroed::<libc::msghdr>();
332 header.msg_iov = data.as_ptr() as *mut libc::iovec;
333 header.msg_iovlen = data.len();
334 header.msg_control = cdata.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *mut c_void;
335 header.msg_controllen = cdata.map(|x| x.len()).unwrap_or(0);
336
337 let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
338 Ok(ret as usize)
339 }
340 }
341
342 pub fn send_msg_to(&self, address: &Address, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
350 unsafe {
351 let mut header = std::mem::zeroed::<libc::msghdr>();
352 header.msg_name = address.as_sockaddr() as *mut c_void;
353 header.msg_namelen = address.len();
354 header.msg_iov = data.as_ptr() as *mut libc::iovec;
355 header.msg_iovlen = data.len();
356 header.msg_control = cdata.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *mut c_void;
357 header.msg_controllen = cdata.map(|x| x.len()).unwrap_or(0);
358
359 let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
360 Ok(ret as usize)
361 }
362 }
363
364 pub fn recv(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<usize> {
370 unsafe {
371 let buffer_ptr = buffer.as_mut_ptr() as *mut c_void;
372 let transferred = check_ret_isize(libc::recv(self.as_raw_fd(), buffer_ptr, buffer.len(), flags | extra_flags::RECVMSG))?;
373 Ok(transferred as usize)
374 }
375 }
376
377 pub fn recv_from(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<(Address, usize)> {
383 unsafe {
384 let buffer_ptr = buffer.as_mut_ptr() as *mut c_void;
385 let mut address = std::mem::MaybeUninit::zeroed();
386 let mut address_len = Address::max_len();
387 let transferred = check_ret_isize(libc::recvfrom(
388 self.as_raw_fd(),
389 buffer_ptr,
390 buffer.len(),
391 flags,
392 Address::as_sockaddr_mut(&mut address),
393 &mut address_len
394 ))?;
395
396 let address = Address::finalize(address, address_len)?;
397 Ok((address, transferred as usize))
398 }
399 }
400
401 pub fn recv_msg(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(usize, c_int)> {
410 let (cdata_buf, cdata_len) = if cdata.capacity() == 0 {
411 (std::ptr::null_mut(), 0)
412 } else {
413 (cdata.buffer.as_mut_ptr(), cdata.capacity())
414 };
415
416 unsafe {
417 let mut header = std::mem::zeroed::<libc::msghdr>();
418 header.msg_iov = data.as_ptr() as *mut libc::iovec;
419 header.msg_iovlen = data.len();
420 header.msg_control = cdata_buf as *mut c_void;
421 header.msg_controllen = cdata_len;
422
423 let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;
424
425 cdata.length = header.msg_controllen as usize;
426 cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;
427 Ok((ret as usize, header.msg_flags))
428 }
429 }
430
431 pub fn recv_msg_from(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(Address, usize, c_int)> {
441 let (cdata_buf, cdata_len) = if cdata.capacity() == 0 {
442 (std::ptr::null_mut(), 0)
443 } else {
444 (cdata.buffer.as_mut_ptr(), cdata.capacity())
445 };
446
447 unsafe {
448 let mut address = std::mem::MaybeUninit::zeroed();
449 let mut header = std::mem::zeroed::<libc::msghdr>();
450 header.msg_name = Address::as_sockaddr_mut(&mut address) as *mut c_void;
451 header.msg_namelen = Address::max_len();
452 header.msg_iov = data.as_ptr() as *mut libc::iovec;
453 header.msg_iovlen = data.len();
454 header.msg_control = cdata_buf as *mut c_void;
455 header.msg_controllen = cdata_len;
456
457 let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;
458 let address = Address::finalize(address, header.msg_namelen)?;
459 cdata.length = header.msg_controllen as usize;
460 cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;
461 Ok((address, ret as usize, header.msg_flags))
462 }
463 }
464}
465
466impl<Address: AsSocketAddress> FromRawFd for Socket<Address> {
467 unsafe fn from_raw_fd(fd: RawFd) -> Self {
468 Self::from_raw_fd(fd)
469 }
470}
471
472impl<Address: AsSocketAddress> AsRawFd for Socket<Address> {
473 fn as_raw_fd(&self) -> RawFd {
474 self.as_raw_fd()
475 }
476}
477
478impl<Address: AsSocketAddress> AsRawFd for &'_ Socket<Address> {
479 fn as_raw_fd(&self) -> RawFd {
480 (*self).as_raw_fd()
481 }
482}
483
484impl<Address: AsSocketAddress> IntoRawFd for Socket<Address> {
485 fn into_raw_fd(self) -> RawFd {
486 self.into_raw_fd()
487 }
488}
489
/// Convert a C-style `c_int` return value into an `io::Result`.
///
/// `-1` maps to the current `errno` (via `last_os_error`); any other value
/// is passed through unchanged.
fn check_ret(ret: c_int) -> std::io::Result<c_int> {
    match ret {
        -1 => Err(std::io::Error::last_os_error()),
        value => Ok(value),
    }
}
501
/// Convert an `isize` return value (as from `send`/`recv`) into an `io::Result`.
///
/// `-1` maps to the current `errno`; any other value is passed through unchanged.
fn check_ret_isize(ret: isize) -> std::io::Result<isize> {
    match ret {
        -1 => Err(std::io::Error::last_os_error()),
        value => Ok(value),
    }
}
513
514fn socket(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<FileDesc> {
516 unsafe {
517 let fd = check_ret(libc::socket(domain, kind, protocol))?;
518 Ok(FileDesc::from_raw_fd(fd))
519 }
520}
521
522fn socketpair(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(FileDesc, FileDesc)> {
524 unsafe {
525 let mut fds = [0; 2];
526 check_ret(libc::socketpair(domain, kind, protocol, fds.as_mut_ptr()))?;
527 Ok((
528 FileDesc::from_raw_fd(fds[0]),
529 FileDesc::from_raw_fd(fds[1]),
530 ))
531 }
532}
533
/// Convert a `bool` to the C convention: `1` for true, `0` for false.
fn bool_to_c_int(value: bool) -> c_int {
    c_int::from(value)
}