1use std::{
2 future::Future,
3 io,
4 mem::{ManuallyDrop, MaybeUninit},
5};
6
7use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
8#[cfg(unix)]
9use compio_driver::op::CreateSocket;
10use compio_driver::{
11 AsRawFd, ToSharedFd, impl_raw_fd,
12 op::{
13 Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, RecvManaged,
14 RecvMsg, RecvResultExt, RecvVectored, ResultTakeBuffer, Send, SendMsg, SendTo,
15 SendToVectored, SendVectored, ShutdownSocket,
16 },
17 syscall,
18};
19use compio_runtime::{Attacher, BorrowedBuffer, BufferPool};
20use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type};
21
22use crate::PollFd;
23
24#[derive(Debug, Clone)]
25pub struct Socket {
26 pub(crate) socket: Attacher<Socket2>,
27}
28
29impl Socket {
30 pub fn from_socket2(socket: Socket2) -> io::Result<Self> {
31 Ok(Self {
32 socket: Attacher::new(socket)?,
33 })
34 }
35
36 pub fn peer_addr(&self) -> io::Result<SockAddr> {
37 self.socket.peer_addr()
38 }
39
40 pub fn local_addr(&self) -> io::Result<SockAddr> {
41 self.socket.local_addr()
42 }
43
44 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
45 PollFd::from_shared_fd(self.to_shared_fd())
46 }
47
48 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
49 PollFd::from_shared_fd(self.socket.into_inner())
50 }
51
52 #[cfg(windows)]
53 pub async fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
54 use std::panic::resume_unwind;
55
56 let socket = compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol))
57 .await
58 .unwrap_or_else(|e| resume_unwind(e))?;
59 Self::from_socket2(socket)
60 }
61
62 #[cfg(unix)]
63 pub async fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
64 use std::os::fd::FromRawFd;
65
66 let op = CreateSocket::new(
67 domain.into(),
68 ty.into(),
69 protocol.map(|p| p.into()).unwrap_or_default(),
70 );
71 let BufResult(res, _) = compio_runtime::submit(op).await;
72 let socket = unsafe { Socket2::from_raw_fd(res? as _) };
73
74 Self::from_socket2(socket)
75 }
76
77 pub async fn bind(addr: &SockAddr, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
78 let socket = Self::new(addr.domain(), ty, protocol).await?;
79 socket.socket.bind(addr)?;
80 Ok(socket)
81 }
82
83 pub fn listen(&self, backlog: i32) -> io::Result<()> {
84 self.socket.listen(backlog)
85 }
86
87 pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
88 self.socket.connect(addr)
89 }
90
91 pub async fn connect_async(&self, addr: &SockAddr) -> io::Result<()> {
92 let op = Connect::new(self.to_shared_fd(), addr.clone());
93 let BufResult(res, _op) = compio_runtime::submit(op).await;
94 #[cfg(windows)]
95 {
96 res?;
97 _op.update_context()?;
98 Ok(())
99 }
100 #[cfg(unix)]
101 {
102 res.map(|_| ())
103 }
104 }
105
106 #[cfg(unix)]
107 pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
108 use std::os::fd::FromRawFd;
109
110 let op = Accept::new(self.to_shared_fd());
111 let BufResult(res, op) = compio_runtime::submit(op).await;
112 let addr = op.into_addr();
113 let accept_sock = unsafe { Socket2::from_raw_fd(res? as _) };
114 let accept_sock = Self::from_socket2(accept_sock)?;
115 Ok((accept_sock, addr))
116 }
117
118 #[cfg(windows)]
119 pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
120 use std::panic::resume_unwind;
121
122 let domain = self.local_addr()?.domain();
123 let ty = self.socket.r#type()?;
125 let protocol = self.socket.protocol()?;
126 let accept_sock =
127 compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol))
128 .await
129 .unwrap_or_else(|e| resume_unwind(e))?;
130 let op = Accept::new(self.to_shared_fd(), accept_sock);
131 let BufResult(res, op) = compio_runtime::submit(op).await;
132 res?;
133 op.update_context()?;
134 let (accept_sock, addr) = op.into_addr()?;
135 Ok((Self::from_socket2(accept_sock)?, addr))
136 }
137
138 pub fn close(self) -> impl Future<Output = io::Result<()>> {
139 let this = ManuallyDrop::new(self);
143 async move {
144 let fd = ManuallyDrop::into_inner(this)
145 .socket
146 .into_inner()
147 .take()
148 .await;
149 if let Some(fd) = fd {
150 let op = CloseSocket::new(fd.into());
151 compio_runtime::submit(op).await.0?;
152 }
153 Ok(())
154 }
155 }
156
157 pub async fn shutdown(&self) -> io::Result<()> {
158 let op = ShutdownSocket::new(self.to_shared_fd(), std::net::Shutdown::Write);
159 compio_runtime::submit(op).await.0?;
160 Ok(())
161 }
162
163 pub async fn recv<B: IoBufMut>(&self, buffer: B) -> BufResult<usize, B> {
164 let fd = self.to_shared_fd();
165 let op = Recv::new(fd, buffer);
166 compio_runtime::submit(op).await.into_inner().map_advanced()
167 }
168
169 pub async fn recv_vectored<V: IoVectoredBufMut>(&self, buffer: V) -> BufResult<usize, V> {
170 let fd = self.to_shared_fd();
171 let op = RecvVectored::new(fd, buffer);
172 compio_runtime::submit(op).await.into_inner().map_advanced()
173 }
174
175 pub async fn recv_managed<'a>(
176 &self,
177 buffer_pool: &'a BufferPool,
178 len: usize,
179 ) -> io::Result<BorrowedBuffer<'a>> {
180 let fd = self.to_shared_fd();
181 let buffer_pool = buffer_pool.try_inner()?;
182 let op = RecvManaged::new(fd, buffer_pool, len)?;
183 compio_runtime::submit_with_flags(op)
184 .await
185 .take_buffer(buffer_pool)
186 }
187
188 pub async fn send<T: IoBuf>(&self, buffer: T) -> BufResult<usize, T> {
189 let fd = self.to_shared_fd();
190 let op = Send::new(fd, buffer);
191 compio_runtime::submit(op).await.into_inner()
192 }
193
194 pub async fn send_vectored<T: IoVectoredBuf>(&self, buffer: T) -> BufResult<usize, T> {
195 let fd = self.to_shared_fd();
196 let op = SendVectored::new(fd, buffer);
197 compio_runtime::submit(op).await.into_inner()
198 }
199
200 pub async fn recv_from<T: IoBufMut>(&self, buffer: T) -> BufResult<(usize, SockAddr), T> {
201 let fd = self.to_shared_fd();
202 let op = RecvFrom::new(fd, buffer);
203 compio_runtime::submit(op)
204 .await
205 .into_inner()
206 .map_addr()
207 .map_advanced()
208 }
209
210 pub async fn recv_from_vectored<T: IoVectoredBufMut>(
211 &self,
212 buffer: T,
213 ) -> BufResult<(usize, SockAddr), T> {
214 let fd = self.to_shared_fd();
215 let op = RecvFromVectored::new(fd, buffer);
216 compio_runtime::submit(op)
217 .await
218 .into_inner()
219 .map_addr()
220 .map_advanced()
221 }
222
223 pub async fn recv_msg<T: IoBufMut, C: IoBufMut>(
224 &self,
225 buffer: T,
226 control: C,
227 ) -> BufResult<(usize, usize, SockAddr), (T, C)> {
228 self.recv_msg_vectored([buffer], control)
229 .await
230 .map_buffer(|([buffer], control)| (buffer, control))
231 }
232
233 pub async fn recv_msg_vectored<T: IoVectoredBufMut, C: IoBufMut>(
234 &self,
235 buffer: T,
236 control: C,
237 ) -> BufResult<(usize, usize, SockAddr), (T, C)> {
238 let fd = self.to_shared_fd();
239 let op = RecvMsg::new(fd, buffer, control);
240 compio_runtime::submit(op)
241 .await
242 .into_inner()
243 .map_addr()
244 .map_advanced()
245 }
246
247 pub async fn send_to<T: IoBuf>(&self, buffer: T, addr: &SockAddr) -> BufResult<usize, T> {
248 let fd = self.to_shared_fd();
249 let op = SendTo::new(fd, buffer, addr.clone());
250 compio_runtime::submit(op).await.into_inner()
251 }
252
253 pub async fn send_to_vectored<T: IoVectoredBuf>(
254 &self,
255 buffer: T,
256 addr: &SockAddr,
257 ) -> BufResult<usize, T> {
258 let fd = self.to_shared_fd();
259 let op = SendToVectored::new(fd, buffer, addr.clone());
260 compio_runtime::submit(op).await.into_inner()
261 }
262
263 pub async fn send_msg<T: IoBuf, C: IoBuf>(
264 &self,
265 buffer: T,
266 control: C,
267 addr: &SockAddr,
268 ) -> BufResult<usize, (T, C)> {
269 self.send_msg_vectored([buffer], control, addr)
270 .await
271 .map_buffer(|([buffer], control)| (buffer, control))
272 }
273
274 pub async fn send_msg_vectored<T: IoVectoredBuf, C: IoBuf>(
275 &self,
276 buffer: T,
277 control: C,
278 addr: &SockAddr,
279 ) -> BufResult<usize, (T, C)> {
280 let fd = self.to_shared_fd();
281 let op = SendMsg::new(fd, buffer, control, addr.clone());
282 compio_runtime::submit(op).await.into_inner()
283 }
284
285 #[cfg(unix)]
286 pub unsafe fn get_socket_option<T: Copy>(&self, level: i32, name: i32) -> io::Result<T> {
287 let mut value: MaybeUninit<T> = MaybeUninit::uninit();
288 let mut len = size_of::<T>() as libc::socklen_t;
289 syscall!(libc::getsockopt(
290 self.socket.as_raw_fd(),
291 level,
292 name,
293 value.as_mut_ptr() as _,
294 &mut len
295 ))
296 .map(|_| {
297 debug_assert_eq!(len as usize, size_of::<T>());
298 value.assume_init()
300 })
301 }
302
303 #[cfg(windows)]
304 pub unsafe fn get_socket_option<T: Copy>(&self, level: i32, name: i32) -> io::Result<T> {
305 let mut value: MaybeUninit<T> = MaybeUninit::uninit();
306 let mut len = size_of::<T>() as i32;
307 syscall!(
308 SOCKET,
309 windows_sys::Win32::Networking::WinSock::getsockopt(
310 self.socket.as_raw_fd() as _,
311 level,
312 name,
313 value.as_mut_ptr() as _,
314 &mut len
315 )
316 )
317 .map(|_| {
318 debug_assert_eq!(len as usize, size_of::<T>());
319 value.assume_init()
321 })
322 }
323
324 #[cfg(unix)]
325 pub unsafe fn set_socket_option<T: Copy>(
326 &self,
327 level: i32,
328 name: i32,
329 value: &T,
330 ) -> io::Result<()> {
331 syscall!(libc::setsockopt(
332 self.socket.as_raw_fd(),
333 level,
334 name,
335 value as *const _ as _,
336 std::mem::size_of::<T>() as _
337 ))
338 .map(|_| ())
339 }
340
341 #[cfg(windows)]
342 pub unsafe fn set_socket_option<T: Copy>(
343 &self,
344 level: i32,
345 name: i32,
346 value: &T,
347 ) -> io::Result<()> {
348 syscall!(
349 SOCKET,
350 windows_sys::Win32::Networking::WinSock::setsockopt(
351 self.socket.as_raw_fd() as _,
352 level,
353 name,
354 value as *const _ as _,
355 std::mem::size_of::<T>() as _
356 )
357 )
358 .map(|_| ())
359 }
360}
361
362impl_raw_fd!(Socket, Socket2, socket, socket);