compio_net/
socket.rs

1use std::{
2    future::Future,
3    io,
4    mem::{ManuallyDrop, MaybeUninit},
5};
6
7use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
8#[cfg(unix)]
9use compio_driver::op::CreateSocket;
10use compio_driver::{
11    AsRawFd, ToSharedFd, impl_raw_fd,
12    op::{
13        Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, RecvManaged,
14        RecvMsg, RecvResultExt, RecvVectored, ResultTakeBuffer, Send, SendMsg, SendTo,
15        SendToVectored, SendVectored, ShutdownSocket,
16    },
17    syscall,
18};
19use compio_runtime::{Attacher, BorrowedBuffer, BufferPool};
20use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type};
21
22use crate::PollFd;
23
24#[derive(Debug, Clone)]
25pub struct Socket {
26    pub(crate) socket: Attacher<Socket2>,
27}
28
29impl Socket {
30    pub fn from_socket2(socket: Socket2) -> io::Result<Self> {
31        Ok(Self {
32            socket: Attacher::new(socket)?,
33        })
34    }
35
36    pub fn peer_addr(&self) -> io::Result<SockAddr> {
37        self.socket.peer_addr()
38    }
39
40    pub fn local_addr(&self) -> io::Result<SockAddr> {
41        self.socket.local_addr()
42    }
43
44    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
45        PollFd::from_shared_fd(self.to_shared_fd())
46    }
47
48    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
49        PollFd::from_shared_fd(self.socket.into_inner())
50    }
51
52    #[cfg(windows)]
53    pub async fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
54        use std::panic::resume_unwind;
55
56        let socket = compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol))
57            .await
58            .unwrap_or_else(|e| resume_unwind(e))?;
59        Self::from_socket2(socket)
60    }
61
62    #[cfg(unix)]
63    pub async fn new(domain: Domain, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
64        use std::os::fd::FromRawFd;
65
66        let op = CreateSocket::new(
67            domain.into(),
68            ty.into(),
69            protocol.map(|p| p.into()).unwrap_or_default(),
70        );
71        let BufResult(res, _) = compio_runtime::submit(op).await;
72        let socket = unsafe { Socket2::from_raw_fd(res? as _) };
73
74        Self::from_socket2(socket)
75    }
76
77    pub async fn bind(addr: &SockAddr, ty: Type, protocol: Option<Protocol>) -> io::Result<Self> {
78        let socket = Self::new(addr.domain(), ty, protocol).await?;
79        socket.socket.bind(addr)?;
80        Ok(socket)
81    }
82
83    pub fn listen(&self, backlog: i32) -> io::Result<()> {
84        self.socket.listen(backlog)
85    }
86
87    pub fn connect(&self, addr: &SockAddr) -> io::Result<()> {
88        self.socket.connect(addr)
89    }
90
91    pub async fn connect_async(&self, addr: &SockAddr) -> io::Result<()> {
92        let op = Connect::new(self.to_shared_fd(), addr.clone());
93        let BufResult(res, _op) = compio_runtime::submit(op).await;
94        #[cfg(windows)]
95        {
96            res?;
97            _op.update_context()?;
98            Ok(())
99        }
100        #[cfg(unix)]
101        {
102            res.map(|_| ())
103        }
104    }
105
106    #[cfg(unix)]
107    pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
108        use std::os::fd::FromRawFd;
109
110        let op = Accept::new(self.to_shared_fd());
111        let BufResult(res, op) = compio_runtime::submit(op).await;
112        let addr = op.into_addr();
113        let accept_sock = unsafe { Socket2::from_raw_fd(res? as _) };
114        let accept_sock = Self::from_socket2(accept_sock)?;
115        Ok((accept_sock, addr))
116    }
117
118    #[cfg(windows)]
119    pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
120        use std::panic::resume_unwind;
121
122        let domain = self.local_addr()?.domain();
123        // We should allow users sending this accepted socket to a new thread.
124        let ty = self.socket.r#type()?;
125        let protocol = self.socket.protocol()?;
126        let accept_sock =
127            compio_runtime::spawn_blocking(move || Socket2::new(domain, ty, protocol))
128                .await
129                .unwrap_or_else(|e| resume_unwind(e))?;
130        let op = Accept::new(self.to_shared_fd(), accept_sock);
131        let BufResult(res, op) = compio_runtime::submit(op).await;
132        res?;
133        op.update_context()?;
134        let (accept_sock, addr) = op.into_addr()?;
135        Ok((Self::from_socket2(accept_sock)?, addr))
136    }
137
138    pub fn close(self) -> impl Future<Output = io::Result<()>> {
139        // Make sure that self won't be dropped after `close` called.
140        // Users may call this method and drop the future immediately. In that way the
141        // `close` should be cancelled.
142        let this = ManuallyDrop::new(self);
143        async move {
144            let fd = ManuallyDrop::into_inner(this)
145                .socket
146                .into_inner()
147                .take()
148                .await;
149            if let Some(fd) = fd {
150                let op = CloseSocket::new(fd.into());
151                compio_runtime::submit(op).await.0?;
152            }
153            Ok(())
154        }
155    }
156
157    pub async fn shutdown(&self) -> io::Result<()> {
158        let op = ShutdownSocket::new(self.to_shared_fd(), std::net::Shutdown::Write);
159        compio_runtime::submit(op).await.0?;
160        Ok(())
161    }
162
163    pub async fn recv<B: IoBufMut>(&self, buffer: B) -> BufResult<usize, B> {
164        let fd = self.to_shared_fd();
165        let op = Recv::new(fd, buffer);
166        compio_runtime::submit(op).await.into_inner().map_advanced()
167    }
168
169    pub async fn recv_vectored<V: IoVectoredBufMut>(&self, buffer: V) -> BufResult<usize, V> {
170        let fd = self.to_shared_fd();
171        let op = RecvVectored::new(fd, buffer);
172        compio_runtime::submit(op).await.into_inner().map_advanced()
173    }
174
175    pub async fn recv_managed<'a>(
176        &self,
177        buffer_pool: &'a BufferPool,
178        len: usize,
179    ) -> io::Result<BorrowedBuffer<'a>> {
180        let fd = self.to_shared_fd();
181        let buffer_pool = buffer_pool.try_inner()?;
182        let op = RecvManaged::new(fd, buffer_pool, len)?;
183        compio_runtime::submit_with_flags(op)
184            .await
185            .take_buffer(buffer_pool)
186    }
187
188    pub async fn send<T: IoBuf>(&self, buffer: T) -> BufResult<usize, T> {
189        let fd = self.to_shared_fd();
190        let op = Send::new(fd, buffer);
191        compio_runtime::submit(op).await.into_inner()
192    }
193
194    pub async fn send_vectored<T: IoVectoredBuf>(&self, buffer: T) -> BufResult<usize, T> {
195        let fd = self.to_shared_fd();
196        let op = SendVectored::new(fd, buffer);
197        compio_runtime::submit(op).await.into_inner()
198    }
199
200    pub async fn recv_from<T: IoBufMut>(&self, buffer: T) -> BufResult<(usize, SockAddr), T> {
201        let fd = self.to_shared_fd();
202        let op = RecvFrom::new(fd, buffer);
203        compio_runtime::submit(op)
204            .await
205            .into_inner()
206            .map_addr()
207            .map_advanced()
208    }
209
210    pub async fn recv_from_vectored<T: IoVectoredBufMut>(
211        &self,
212        buffer: T,
213    ) -> BufResult<(usize, SockAddr), T> {
214        let fd = self.to_shared_fd();
215        let op = RecvFromVectored::new(fd, buffer);
216        compio_runtime::submit(op)
217            .await
218            .into_inner()
219            .map_addr()
220            .map_advanced()
221    }
222
223    pub async fn recv_msg<T: IoBufMut, C: IoBufMut>(
224        &self,
225        buffer: T,
226        control: C,
227    ) -> BufResult<(usize, usize, SockAddr), (T, C)> {
228        self.recv_msg_vectored([buffer], control)
229            .await
230            .map_buffer(|([buffer], control)| (buffer, control))
231    }
232
233    pub async fn recv_msg_vectored<T: IoVectoredBufMut, C: IoBufMut>(
234        &self,
235        buffer: T,
236        control: C,
237    ) -> BufResult<(usize, usize, SockAddr), (T, C)> {
238        let fd = self.to_shared_fd();
239        let op = RecvMsg::new(fd, buffer, control);
240        compio_runtime::submit(op)
241            .await
242            .into_inner()
243            .map_addr()
244            .map_advanced()
245    }
246
247    pub async fn send_to<T: IoBuf>(&self, buffer: T, addr: &SockAddr) -> BufResult<usize, T> {
248        let fd = self.to_shared_fd();
249        let op = SendTo::new(fd, buffer, addr.clone());
250        compio_runtime::submit(op).await.into_inner()
251    }
252
253    pub async fn send_to_vectored<T: IoVectoredBuf>(
254        &self,
255        buffer: T,
256        addr: &SockAddr,
257    ) -> BufResult<usize, T> {
258        let fd = self.to_shared_fd();
259        let op = SendToVectored::new(fd, buffer, addr.clone());
260        compio_runtime::submit(op).await.into_inner()
261    }
262
263    pub async fn send_msg<T: IoBuf, C: IoBuf>(
264        &self,
265        buffer: T,
266        control: C,
267        addr: &SockAddr,
268    ) -> BufResult<usize, (T, C)> {
269        self.send_msg_vectored([buffer], control, addr)
270            .await
271            .map_buffer(|([buffer], control)| (buffer, control))
272    }
273
274    pub async fn send_msg_vectored<T: IoVectoredBuf, C: IoBuf>(
275        &self,
276        buffer: T,
277        control: C,
278        addr: &SockAddr,
279    ) -> BufResult<usize, (T, C)> {
280        let fd = self.to_shared_fd();
281        let op = SendMsg::new(fd, buffer, control, addr.clone());
282        compio_runtime::submit(op).await.into_inner()
283    }
284
285    #[cfg(unix)]
286    pub unsafe fn get_socket_option<T: Copy>(&self, level: i32, name: i32) -> io::Result<T> {
287        let mut value: MaybeUninit<T> = MaybeUninit::uninit();
288        let mut len = size_of::<T>() as libc::socklen_t;
289        syscall!(libc::getsockopt(
290            self.socket.as_raw_fd(),
291            level,
292            name,
293            value.as_mut_ptr() as _,
294            &mut len
295        ))
296        .map(|_| {
297            debug_assert_eq!(len as usize, size_of::<T>());
298            // SAFETY: The value is initialized by `getsockopt`.
299            value.assume_init()
300        })
301    }
302
303    #[cfg(windows)]
304    pub unsafe fn get_socket_option<T: Copy>(&self, level: i32, name: i32) -> io::Result<T> {
305        let mut value: MaybeUninit<T> = MaybeUninit::uninit();
306        let mut len = size_of::<T>() as i32;
307        syscall!(
308            SOCKET,
309            windows_sys::Win32::Networking::WinSock::getsockopt(
310                self.socket.as_raw_fd() as _,
311                level,
312                name,
313                value.as_mut_ptr() as _,
314                &mut len
315            )
316        )
317        .map(|_| {
318            debug_assert_eq!(len as usize, size_of::<T>());
319            // SAFETY: The value is initialized by `getsockopt`.
320            value.assume_init()
321        })
322    }
323
324    #[cfg(unix)]
325    pub unsafe fn set_socket_option<T: Copy>(
326        &self,
327        level: i32,
328        name: i32,
329        value: &T,
330    ) -> io::Result<()> {
331        syscall!(libc::setsockopt(
332            self.socket.as_raw_fd(),
333            level,
334            name,
335            value as *const _ as _,
336            std::mem::size_of::<T>() as _
337        ))
338        .map(|_| ())
339    }
340
341    #[cfg(windows)]
342    pub unsafe fn set_socket_option<T: Copy>(
343        &self,
344        level: i32,
345        name: i32,
346        value: &T,
347    ) -> io::Result<()> {
348        syscall!(
349            SOCKET,
350            windows_sys::Win32::Networking::WinSock::setsockopt(
351                self.socket.as_raw_fd() as _,
352                level,
353                name,
354                value as *const _ as _,
355                std::mem::size_of::<T>() as _
356            )
357        )
358        .map(|_| ())
359    }
360}
361
362impl_raw_fd!(Socket, Socket2, socket, socket);