posix_socket/
socket.rs

1use filedesc::FileDesc;
2use std::io::{IoSlice, IoSliceMut};
3use std::os::raw::{c_int, c_void};
4use std::os::unix::io::{RawFd, AsRawFd, IntoRawFd, FromRawFd};
5
6use crate::AsSocketAddress;
7use crate::ancillary::SocketAncillary;
8
9/// A POSIX socket.
10pub struct Socket<Address> {
11	fd: FileDesc,
12	_address: std::marker::PhantomData<fn() -> Address>,
13}
14
15#[cfg(not(any(target_os = "apple", target_os = "solaris")))]
16mod extra_flags {
17	pub const SENDMSG: std::os::raw::c_int = libc::MSG_NOSIGNAL;
18	pub const RECVMSG: std::os::raw::c_int = libc::MSG_CMSG_CLOEXEC;
19}
20
21#[cfg(any(target_os = "apple", target_os = "solaris"))]
22mod extra_flags {
23	pub const SENDMSG: std::os::raw::c_int = 0;
24	pub const RECVMSG: std::os::raw::c_int = 0;
25}
26
27impl<Address: AsSocketAddress> Socket<Address> {
28	/// Wrap a file descriptor in a Socket.
29	///
30	/// On Apple systems, this sets the SO_NOSIGPIPE option to prevent SIGPIPE signals.
31	fn wrap(fd: FileDesc) -> std::io::Result<Self> {
32		let wrapped = Self {
33			fd,
34			_address: std::marker::PhantomData,
35		};
36
37		#[cfg(target_os = "apple")]
38		wrapped.set_option(libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1 as c_int)?;
39
40		Ok(wrapped)
41	}
42
43	/// Create a new socket with the specified type and protocol.
44	///
45	/// The domain is taken from the `Address` type.
46	///
47	/// The created socket has the `close-on-exec` flag set.
48	/// The flag will be set atomically when the socket is created if the platform supports it.
49	///
50	/// See `man socket` for more information.
51	pub fn new(kind: c_int, protocol: c_int) -> std::io::Result<Self>
52	where
53		Address: crate::SpecificSocketAddress,
54	{
55		Self::new_generic(Address::static_family() as c_int, kind, protocol)
56	}
57
58	/// Create a new socket with the specified domain, type and protocol.
59	///
60	/// Unless you are working with generic socket addresses,
61	/// you should normally prefer `Self::new`.
62	///
63	/// The created socket has the `close-on-exec` flag set.
64	/// The flag will be set atomically when the socket is created if the platform supports it.
65	///
66	/// See `man socket` for more information.
67	pub fn new_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<Self> {
68		socket(domain, kind | libc::SOCK_CLOEXEC, protocol)
69			.or_else(|e| {
70				// Fall back to setting close-on-exec after creation if SOCK_CLOEXEC is not supported.
71				if e.raw_os_error() == Some(libc::EINVAL) {
72					let fd = socket(domain, kind, protocol)?;
73					fd.set_close_on_exec(true)?;
74					Ok(fd)
75				} else {
76					Err(e)
77				}
78			})
79			.and_then(Self::wrap)
80	}
81
82	/// Create a connected pair of socket with the specified type and protocol.
83	///
84	/// The domain is taken from the `Address` type.
85	///
86	/// The created sockets have the `close-on-exec` flag set.
87	/// The flag will be set atomically when the sockets are created if the platform supports it.
88	///
89	/// See `man socketpair` and `man socket` for more information.
90	pub fn pair(kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)>
91	where
92		Address: crate::SpecificSocketAddress,
93	{
94		Self::pair_generic(Address::static_family() as c_int, kind, protocol)
95	}
96
97	/// Create a connected pair of socket with the specified domain, type and protocol.
98	///
99	/// Unless you are working with generic socket addresses,
100	/// you should normally prefer `Self::pair`.
101	///
102	/// The created sockets have the `close-on-exec` flag set.
103	/// The flag will be set atomically when the sockets are created if the platform supports it.
104	///
105	/// See `man socketpair` and `man socket` for more information.
106	pub fn pair_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)> {
107		socketpair(domain, kind | libc::SOCK_CLOEXEC, protocol)
108			.or_else(|e| {
109				// Fall back to setting close-on-exec after creation if SOCK_CLOEXEC is not supported.
110				if e.raw_os_error() == Some(libc::EINVAL) {
111					let (a, b) = socketpair(domain, kind, protocol)?;
112					a.set_close_on_exec(true)?;
113					b.set_close_on_exec(true)?;
114					Ok((a, b))
115				} else {
116					Err(e)
117				}
118			})
119			.and_then(|(a, b)| {
120				Ok((Self::wrap(a)?, Self::wrap(b)?))
121			})
122	}
123
124	/// Try to clone the socket.
125	///
126	/// This is implemented by duplicating the file descriptor.
127	/// The returned [`Socket`] refers to the same kernel object.
128	///
129	/// The underlying file descriptor of the new socket will have the `close-on-exec` flag set.
130	/// If the platform supports it, the flag will be set atomically when the file descriptor is duplicated.
131	pub fn try_clone(&self) -> std::io::Result<Self> {
132		Ok(Self {
133			fd: self.fd.duplicate()?,
134			_address: std::marker::PhantomData,
135		})
136	}
137
138	/// Wrap a raw file descriptor in a [`Socket`].
139	///
140	/// This function sets no flags or options on the file descriptor or socket.
141	/// It is your own responsibility to make sure the close-on-exec flag is already set,
142	/// and that the `SO_NOSIGPIPE` option is set on Apple platforms.
143	pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
144		Self {
145			fd: FileDesc::from_raw_fd(fd),
146			_address: std::marker::PhantomData,
147		}
148	}
149
150	/// Get the raw file descriptor.
151	///
152	/// This function does not release ownership of the underlying file descriptor.
153	/// The file descriptor will still be closed when the [`FileDesc`] is dropped.
154	pub fn as_raw_fd(&self) -> RawFd {
155		self.fd.as_raw_fd()
156	}
157
158	/// Release and get the raw file descriptor.
159	///
160	/// This function releases ownership of the underlying file descriptor.
161	/// The file descriptor will not be closed.
162	pub fn into_raw_fd(self) -> RawFd {
163		self.fd.into_raw_fd()
164	}
165
166	/// Set a socket option.
167	///
168	/// See `man setsockopt` for more information.
169	fn set_option<T: Copy>(&self, level: c_int, option: c_int, value: T) -> std::io::Result<()> {
170		unsafe {
171			let value = &value as *const T as *const c_void;
172			let length = std::mem::size_of::<T>() as libc::socklen_t;
173			check_ret(libc::setsockopt(self.as_raw_fd(), level, option, value, length))?;
174			Ok(())
175		}
176	}
177
178	/// Get the value of a socket option.
179	///
180	/// See `man getsockopt` for more information.
181	fn get_option<T: Copy>(&self, level: c_int, option: c_int) -> std::io::Result<T> {
182		unsafe {
183			let mut output = std::mem::MaybeUninit::zeroed();
184			let output_ptr = output.as_mut_ptr() as *mut c_void;
185			let mut length = std::mem::size_of::<T>() as libc::socklen_t;
186			check_ret(libc::getsockopt(self.as_raw_fd(), level, option, output_ptr, &mut length))?;
187			assert_eq!(length, std::mem::size_of::<T>() as libc::socklen_t);
188			Ok(output.assume_init())
189		}
190	}
191
192	/// Put the socket in blocking or non-blocking mode.
193	pub fn set_nonblocking(&self, non_blocking: bool) -> std::io::Result<()> {
194		self.set_option(libc::SOL_SOCKET, libc::O_NONBLOCK, bool_to_c_int(non_blocking))
195	}
196
197	/// Check if the socket in blocking or non-blocking mode.
198	pub fn get_nonblocking(&self) -> std::io::Result<bool> {
199		let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::O_NONBLOCK)?;
200		Ok(raw != 0)
201	}
202
203	/// Gets the value of the SO_ERROR option on this socket.
204	///
205	/// This will retrieve the stored error in the underlying socket, clearing the field in the process.
206	/// This can be useful for checking errors between calls.
207	pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
208		let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::SO_ERROR)?;
209		if raw == 0 {
210			Ok(None)
211		} else {
212			Ok(Some(std::io::Error::from_raw_os_error(raw)))
213		}
214	}
215
216	/// Get the local address the socket is bound to.
217	pub fn local_addr(&self) -> std::io::Result<Address> {
218		unsafe {
219			let mut address = std::mem::MaybeUninit::<Address>::zeroed();
220			let mut len = Address::max_len();
221			check_ret(libc::getsockname(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len))?;
222			Address::finalize(address, len)
223		}
224	}
225
226	/// Get the remote address the socket is connected to.
227	pub fn peer_addr(&self) -> std::io::Result<Address> {
228		unsafe {
229			let mut address = std::mem::MaybeUninit::<Address>::zeroed();
230			let mut len = Address::max_len();
231			check_ret(libc::getpeername(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len))?;
232			Address::finalize(address, len)
233		}
234	}
235
236	/// Connect the socket to a remote address.
237	///
238	/// It depends on the exact socket type what it means to connect the socket.
239	/// See `man connect` for more information.
240	pub fn connect(&self, address: &Address) -> std::io::Result<()> {
241		unsafe {
242			check_ret(libc::connect(self.as_raw_fd(), address.as_sockaddr(), address.len()))?;
243			Ok(())
244		}
245	}
246
247	/// Bind the socket to a local address.
248	///
249	/// It depends on the exact socket type what it means to bind the socket.
250	/// See `man bind` for more information.
251	pub fn bind(&self, address: &Address) -> std::io::Result<()> {
252		unsafe {
253			check_ret(libc::bind(self.as_raw_fd(), address.as_sockaddr(), address.len()))?;
254			Ok(())
255		}
256	}
257
258	/// Put the socket in listening mode, ready to accept connections.
259	///
260	/// Once the socket is in listening mode,
261	/// new connections can be accepted with [`accept()`](Socket::accept).
262	///
263	/// Not all socket types can be put into listening mode.
264	/// See `man listen` for more information.
265	pub fn listen(&self, backlog: c_int) -> std::io::Result<()> {
266		unsafe {
267			check_ret(libc::listen(self.as_raw_fd(), backlog))?;
268			Ok(())
269		}
270	}
271
272	/// Accept a new connection on the socket.
273	///
274	/// The socket must have been put in listening mode
275	/// with a call to [`listen()`](Socket::listen).
276	///
277	/// Not all socket types can be put into listening mode or accept connections.
278	/// See `man listen` for more information.
279	pub fn accept(&self) -> std::io::Result<(Self, Address)> {
280		unsafe {
281			let mut address = std::mem::MaybeUninit::zeroed();
282			let mut len = Address::max_len();
283			let fd = check_ret(libc::accept4(self.as_raw_fd(), Address::as_sockaddr_mut(&mut address), &mut len, libc::SOCK_CLOEXEC))?;
284			let socket = Self::wrap(FileDesc::from_raw_fd(fd))?;
285			let address = Address::finalize(address, len)?;
286			Ok((socket, address))
287		}
288	}
289
290	/// Send data over the socket to the connected peer.
291	///
292	/// Returns the number of transferred bytes, or an error.
293	///
294	/// See `man send` for more information.
295	pub fn send(&self, data: &[u8], flags: c_int) -> std::io::Result<usize> {
296		unsafe {
297			let data_ptr = data.as_ptr() as *const c_void;
298			let transferred = check_ret_isize(libc::send(self.as_raw_fd(), data_ptr, data.len(), flags | extra_flags::SENDMSG))?;
299			Ok(transferred as usize)
300		}
301	}
302
303	/// Send data over the socket to the specified address.
304	///
305	/// This function is only valid for connectionless protocols such as UDP or unix datagram sockets.
306	///
307	/// Returns the number of transferred bytes, or an error.
308	///
309	/// See `man sendto` for more information.
310	pub fn send_to(&self, data: &[u8], address: &Address, flags: c_int) -> std::io::Result<usize> {
311		unsafe {
312			let data_ptr = data.as_ptr() as *const c_void;
313			let transferred = check_ret_isize(libc::sendto(
314				self.as_raw_fd(),
315				data_ptr,
316				data.len(),
317				flags | extra_flags::SENDMSG,
318				address.as_sockaddr(), address.len()
319			))?;
320			Ok(transferred as usize)
321		}
322	}
323
324	/// Send a message over the socket to the connected peer.
325	///
326	/// Returns the number of transferred bytes, or an error.
327	///
328	/// See `man sendmsg` for more information.
329	pub fn send_msg(&self, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
330		unsafe {
331			let mut header = std::mem::zeroed::<libc::msghdr>();
332			header.msg_iov = data.as_ptr() as *mut libc::iovec;
333			header.msg_iovlen = data.len();
334			header.msg_control = cdata.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *mut c_void;
335			header.msg_controllen = cdata.map(|x| x.len()).unwrap_or(0);
336
337			let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
338			Ok(ret as usize)
339		}
340	}
341
342	/// Send a message over the socket to the specified address.
343	///
344	/// This function is only valid for connectionless protocols such as UDP or unix datagram sockets.
345	///
346	/// Returns the number of transferred bytes, or an error.
347	///
348	/// See `man sendmsg` for more information.
349	pub fn send_msg_to(&self, address: &Address, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
350		unsafe {
351			let mut header = std::mem::zeroed::<libc::msghdr>();
352			header.msg_name = address.as_sockaddr() as *mut c_void;
353			header.msg_namelen = address.len();
354			header.msg_iov = data.as_ptr() as *mut libc::iovec;
355			header.msg_iovlen = data.len();
356			header.msg_control = cdata.map(|x| x.as_ptr()).unwrap_or(std::ptr::null()) as *mut c_void;
357			header.msg_controllen = cdata.map(|x| x.len()).unwrap_or(0);
358
359			let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
360			Ok(ret as usize)
361		}
362	}
363
364	/// Receive a data on the socket from the connected peer.
365	///
366	/// Returns the number of transferred bytes, or an error.
367	///
368	/// See `man recv` for more information.
369	pub fn recv(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<usize> {
370		unsafe {
371			let buffer_ptr = buffer.as_mut_ptr() as *mut c_void;
372			let transferred = check_ret_isize(libc::recv(self.as_raw_fd(), buffer_ptr, buffer.len(), flags | extra_flags::RECVMSG))?;
373			Ok(transferred as usize)
374		}
375	}
376
377	/// Receive a data on the socket.
378	///
379	/// Returns the address of the sender and the number of transferred bytes, or an error.
380	///
381	/// See `man recvfrom` for more information.
382	pub fn recv_from(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<(Address, usize)> {
383		unsafe {
384			let buffer_ptr = buffer.as_mut_ptr() as *mut c_void;
385			let mut address = std::mem::MaybeUninit::zeroed();
386			let mut address_len = Address::max_len();
387			let transferred = check_ret_isize(libc::recvfrom(
388				self.as_raw_fd(),
389				buffer_ptr,
390				buffer.len(),
391				flags,
392				Address::as_sockaddr_mut(&mut address),
393				&mut address_len
394			))?;
395
396			let address = Address::finalize(address, address_len)?;
397			Ok((address, transferred as usize))
398		}
399	}
400
401	/// Receive a message on the socket from the connected peer.
402	///
403	/// If the call succeeds, the function returns a tuple with:
404	///   * the number of transferred bytes
405	///   * the number of transferred control message bytes
406	///   * the reception flags
407	///
408	/// See `man recvmsg` for more information.
409	pub fn recv_msg(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(usize, c_int)> {
410		let (cdata_buf, cdata_len) = if cdata.capacity() == 0 {
411			(std::ptr::null_mut(), 0)
412		} else {
413			(cdata.buffer.as_mut_ptr(), cdata.capacity())
414		};
415
416		unsafe {
417			let mut header = std::mem::zeroed::<libc::msghdr>();
418			header.msg_iov = data.as_ptr() as *mut libc::iovec;
419			header.msg_iovlen = data.len();
420			header.msg_control = cdata_buf as *mut c_void;
421			header.msg_controllen = cdata_len;
422
423			let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;
424
425			cdata.length = header.msg_controllen as usize;
426			cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;
427			Ok((ret as usize, header.msg_flags))
428		}
429	}
430
431	/// Receive a message on the socket from any address.
432	///
433	/// If the call succeeds, the function returns a tuple with:
434	///   * the address of the sender
435	///   * the number of transferred bytes
436	///   * the number of transferred control message bytes
437	///   * the reception flags
438	///
439	/// See `man recvmsg` for more information.
440	pub fn recv_msg_from(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(Address, usize, c_int)> {
441		let (cdata_buf, cdata_len) = if cdata.capacity() == 0 {
442			(std::ptr::null_mut(), 0)
443		} else {
444			(cdata.buffer.as_mut_ptr(), cdata.capacity())
445		};
446
447		unsafe {
448			let mut address = std::mem::MaybeUninit::zeroed();
449			let mut header = std::mem::zeroed::<libc::msghdr>();
450			header.msg_name = Address::as_sockaddr_mut(&mut address) as *mut c_void;
451			header.msg_namelen = Address::max_len();
452			header.msg_iov = data.as_ptr() as *mut libc::iovec;
453			header.msg_iovlen = data.len();
454			header.msg_control = cdata_buf as *mut c_void;
455			header.msg_controllen = cdata_len;
456
457			let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;
458			let address = Address::finalize(address, header.msg_namelen)?;
459			cdata.length = header.msg_controllen as usize;
460			cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;
461			Ok((address, ret as usize, header.msg_flags))
462		}
463	}
464}
465
466impl<Address: AsSocketAddress> FromRawFd for Socket<Address> {
467	unsafe fn from_raw_fd(fd: RawFd) -> Self {
468		Self::from_raw_fd(fd)
469	}
470}
471
472impl<Address: AsSocketAddress> AsRawFd for Socket<Address> {
473	fn as_raw_fd(&self) -> RawFd {
474		self.as_raw_fd()
475	}
476}
477
478impl<Address: AsSocketAddress> AsRawFd for &'_ Socket<Address> {
479	fn as_raw_fd(&self) -> RawFd {
480		(*self).as_raw_fd()
481	}
482}
483
484impl<Address: AsSocketAddress> IntoRawFd for Socket<Address> {
485	fn into_raw_fd(self) -> RawFd {
486		self.into_raw_fd()
487	}
488}
489
/// Convert the `c_int` return value of a libc function into an [`std::io::Result`].
///
/// A return value of `-1` signals failure, and the error is taken from
/// [`last_os_error()`](std::io::Error::last_os_error).
/// Every other value is passed through unchanged as [`Ok`].
fn check_ret(ret: c_int) -> std::io::Result<c_int> {
	match ret {
		-1 => Err(std::io::Error::last_os_error()),
		value => Ok(value),
	}
}
501
/// Convert the `isize` return value of a libc function (such as `send` or `recv`) into an [`std::io::Result`].
///
/// A return value of `-1` signals failure, and the error is taken from
/// [`last_os_error()`](std::io::Error::last_os_error).
/// Every other value is passed through unchanged as [`Ok`].
fn check_ret_isize(ret: isize) -> std::io::Result<isize> {
	if ret != -1 {
		return Ok(ret);
	}
	Err(std::io::Error::last_os_error())
}
513
514/// Create a socket and wrap the created file descriptor.
515fn socket(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<FileDesc> {
516	unsafe {
517		let fd = check_ret(libc::socket(domain, kind, protocol))?;
518		Ok(FileDesc::from_raw_fd(fd))
519	}
520}
521
522/// Create a socket pair and wrap the created file descriptors.
523fn socketpair(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(FileDesc, FileDesc)> {
524	unsafe {
525		let mut fds = [0; 2];
526		check_ret(libc::socketpair(domain, kind, protocol, fds.as_mut_ptr()))?;
527		Ok((
528			FileDesc::from_raw_fd(fds[0]),
529			FileDesc::from_raw_fd(fds[1]),
530		))
531	}
532}
533
/// Convert a Rust `bool` to the C convention: `1` for `true`, `0` for `false`.
fn bool_to_c_int(value: bool) -> c_int {
	c_int::from(value)
}