Skip to main content

compio_driver/sys/op/socket/
unix.rs

1use std::{net::Shutdown, num::NonZeroU32};
2
3use rustix::{
4    io::close,
5    net::{
6        AddressFamily, Protocol, RecvAncillaryBuffer, SendAncillaryBuffer, SocketAddrAny,
7        SocketType, acceptfrom_with, bind, connect, listen, recv, recvfrom, recvmsg, send, sendmsg,
8        sendmsg_addr, sendto, shutdown, socket_with,
9    },
10};
11
12use crate::sys::op::*;
13
14impl<S: AsFd> Accept<S> {
15    pub(crate) fn call(&mut self) -> io::Result<usize> {
16        let (owned, addr) = acceptfrom_with(self.fd.as_fd(), SOCKET_FLAG)?;
17        let fd = owned.as_raw_fd();
18        let socket: Socket2 = owned.into();
19
20        if cfg!(apple) {
21            socket.set_cloexec(true)?;
22            socket.set_nonblocking(true)?;
23        }
24
25        copy_addr_from(&mut self.buffer, &mut self.addr_len, addr);
26        self.accepted_fd = Some(socket);
27
28        Ok(fd as usize)
29    }
30}
31
32impl<S: AsFd> Connect<S> {
33    pub(crate) fn call(&self) -> io::Result<usize> {
34        connect(&self.fd, &SockAddrArg(&self.addr))?;
35        Ok(0)
36    }
37}
38
39impl<T: IoBuf, S: AsFd> Send<T, S> {
40    pub(crate) fn call(&mut self) -> io::Result<usize> {
41        send(self.fd.as_fd(), self.buffer.as_init(), self.flags).map_err(Into::into)
42    }
43}
44
45impl<T: IoBuf, S: AsFd> SendTo<T, S> {
46    pub(crate) fn call(&self) -> io::Result<usize> {
47        sendto(
48            self.header.fd.as_fd(),
49            self.buffer.as_init(),
50            self.header.flags,
51            &SockAddrArg(&self.header.addr),
52        )
53        .map_err(Into::into)
54    }
55}
56
57impl<T: IoVectoredBuf, S: AsFd> SendVectored<T, S> {
58    pub(crate) fn call(&self, control: &mut SendVectoredControl) -> io::Result<usize> {
59        let mut anc = SendAncillaryBuffer::default();
60
61        sendmsg(
62            self.fd.as_fd(),
63            io_slice(&control.slices),
64            &mut anc,
65            self.flags,
66        )
67        .map_err(Into::into)
68    }
69}
70
71impl<T: IoVectoredBuf, S: AsFd> SendToVectored<T, S> {
72    pub(crate) fn call(&mut self, control: &mut SendMsgControl) -> io::Result<usize> {
73        let addr = SockAddrArg(&self.header.addr);
74        let mut anc = SendAncillaryBuffer::default();
75        let buf = io_slice(&control.slices);
76
77        sendmsg_addr(
78            self.header.fd.as_fd(),
79            &addr,
80            buf,
81            &mut anc,
82            self.header.flags,
83        )
84        .map_err(Into::into)
85    }
86}
87
88impl<T: IoVectoredBuf, C: IoBuf, S: AsFd> SendMsg<T, C, S> {
89    pub(crate) fn call(&mut self, control: &mut SendMsgControl) -> io::Result<usize> {
90        // Both rustix and nix expose api that uses structured AncillaryBuffer
91        // building, no way to just throw in an ancillary buf. Fallback to libc here.
92        syscall!(libc::sendmsg(
93            self.fd.as_fd().as_raw_fd(),
94            &control.msg,
95            self.flags.bits() as _,
96        ))
97    }
98}
99
100impl<T: IoBufMut, S: AsFd> Recv<T, S> {
101    pub(crate) fn call(&mut self) -> io::Result<usize> {
102        let (_, len) = recv(self.fd.as_fd(), self.buffer.as_uninit(), self.flags)?;
103
104        Ok(len)
105    }
106}
107
108impl<T: IoVectoredBufMut, S: AsFd> RecvVectored<T, S> {
109    pub(crate) fn call(&mut self, control: &mut RecvVectoredControl) -> io::Result<usize> {
110        let res = recvmsg(
111            self.fd.as_fd(),
112            io_slice_mut(&mut control.slices),
113            &mut RecvAncillaryBuffer::default(),
114            self.flags,
115        )?;
116
117        // Kernel may truncate and return a larger-than-buffer size
118        Ok(res.bytes.min(self.buffer.total_capacity()))
119    }
120}
121
122impl<S: AsFd> RecvFromHeader<S> {
123    pub fn set_addr(&mut self, addr: Option<SocketAddrAny>) {
124        copy_addr_from(&mut self.addr, &mut self.addr_len, addr)
125    }
126}
127
128impl<T: IoBufMut, S: AsFd> RecvFrom<T, S> {
129    pub(crate) fn call(&mut self) -> io::Result<usize> {
130        let (_, len, addr) = recvfrom(&self.header.fd, self.buffer.as_uninit(), self.header.flags)?;
131
132        self.header.set_addr(addr);
133
134        Ok(len.min(self.buffer.buf_capacity()))
135    }
136}
137
138impl<T: IoVectoredBufMut, C: IoBufMut, S: AsFd> RecvMsg<T, C, S> {
139    pub(crate) fn call(&mut self, control: &mut RecvMsgControl) -> io::Result<usize> {
140        let res = syscall!(libc::recvmsg(
141            self.header.fd.as_fd().as_raw_fd(),
142            &raw mut control.msg,
143            self.header.flags.bits() as _,
144        ))?;
145
146        self.update_control(control);
147
148        Ok(res)
149    }
150}
151
152impl<T: IoVectoredBufMut, S: AsFd> RecvFromVectored<T, S> {
153    pub(crate) fn call(&mut self, control: &mut RecvMsgControl) -> io::Result<usize> {
154        let res = recvmsg(
155            &self.header.fd,
156            io_slice_mut(&mut control.slices),
157            &mut RecvAncillaryBuffer::default(),
158            self.header.flags,
159        )?;
160
161        self.header.set_addr(res.address);
162
163        Ok(res.bytes)
164    }
165}
166
167/// Create a socket.
168pub struct CreateSocket {
169    pub(crate) domain: AddressFamily,
170    pub(crate) socket_type: SocketType,
171    pub(crate) protocol: Option<Protocol>,
172    pub(crate) opened_fd: Option<Socket2>,
173}
174
175impl CreateSocket {
176    /// Create [`CreateSocket`].
177    pub fn new(domain: i32, socket_type: i32, protocol: i32) -> Self {
178        let domain = AddressFamily::from_raw(domain as _);
179        let socket_type = SocketType::from_raw(socket_type as _);
180        let protocol = NonZeroU32::new(protocol as _).map(Protocol::from_raw);
181
182        Self {
183            domain,
184            socket_type,
185            protocol,
186            opened_fd: None,
187        }
188    }
189
190    pub(crate) fn call(&mut self) -> io::Result<usize> {
191        let owned = socket_with(self.domain, self.socket_type, SOCKET_FLAG, self.protocol)?;
192        let fd = owned.as_raw_fd();
193        let socket: Socket2 = owned.into();
194
195        #[cfg(apple)]
196        {
197            socket.set_cloexec(true)?;
198            socket.set_nosigpipe(true)?;
199            socket.set_nonblocking(true)?;
200        }
201
202        self.opened_fd = Some(socket);
203        Ok(fd as _)
204    }
205}
206
207impl IntoInner for CreateSocket {
208    type Inner = Socket2;
209
210    fn into_inner(self) -> Self::Inner {
211        self.opened_fd.expect("socket not created")
212    }
213}
214
215/// Bind a socket to an address.
216pub struct Bind<S> {
217    pub(crate) fd: S,
218    pub(crate) addr: SockAddr,
219}
220
221impl<S> Bind<S> {
222    /// Create [`Bind`].
223    pub fn new(fd: S, addr: SockAddr) -> Self {
224        Self { fd, addr }
225    }
226}
227
228impl<S: AsFd> Bind<S> {
229    pub(crate) fn call(&self) -> io::Result<usize> {
230        bind(self.fd.as_fd(), &SockAddrArg(&self.addr))?;
231        Ok(0)
232    }
233}
234
235/// Listen for connections on a socket.
236pub struct Listen<S> {
237    pub(crate) fd: S,
238    pub(crate) backlog: i32,
239}
240
241impl<S> Listen<S> {
242    /// Create [`Listen`].
243    pub fn new(fd: S, backlog: i32) -> Self {
244        Self { fd, backlog }
245    }
246}
247
248impl<S: AsFd> Listen<S> {
249    pub(crate) fn call(&self) -> io::Result<usize> {
250        listen(self.fd.as_fd(), self.backlog)?;
251        Ok(0)
252    }
253}
254
255/// Shutdown a socket.
256pub struct ShutdownSocket<S> {
257    pub(crate) fd: S,
258    pub(crate) how: Shutdown,
259}
260
261impl<S> ShutdownSocket<S> {
262    /// Create [`ShutdownSocket`].
263    pub fn new(fd: S, how: Shutdown) -> Self {
264        Self { fd, how }
265    }
266}
267
268impl<S: AsFd> ShutdownSocket<S> {
269    #[cfg(io_uring)]
270    pub(crate) fn how(&self) -> i32 {
271        match self.how {
272            Shutdown::Write => libc::SHUT_WR,
273            Shutdown::Read => libc::SHUT_RD,
274            Shutdown::Both => libc::SHUT_RDWR,
275        }
276    }
277
278    pub(crate) fn call(&mut self) -> io::Result<usize> {
279        let how = match self.how {
280            Shutdown::Write => rustix::net::Shutdown::Write,
281            Shutdown::Read => rustix::net::Shutdown::Read,
282            Shutdown::Both => rustix::net::Shutdown::Both,
283        };
284        shutdown(&self.fd, how)?;
285        Ok(0)
286    }
287}
288
289impl CloseSocket {
290    pub(crate) fn call(&mut self) -> io::Result<usize> {
291        unsafe { close(self.fd.as_raw_fd()) };
292        Ok(0)
293    }
294}
295
296/// Accept a connection.
297pub struct Accept<S> {
298    pub(crate) fd: S,
299    pub(crate) buffer: SockAddrStorage,
300    pub(crate) addr_len: socklen_t,
301    pub(crate) accepted_fd: Option<Socket2>,
302}
303
304impl<S> Accept<S> {
305    /// Create [`Accept`].
306    pub fn new(fd: S) -> Self {
307        let buffer = SockAddrStorage::zeroed();
308        let addr_len = buffer.size_of();
309        Self {
310            fd,
311            buffer,
312            addr_len,
313            accepted_fd: None,
314        }
315    }
316}
317
318impl<S> IntoInner for Accept<S> {
319    type Inner = (Socket2, SockAddr);
320
321    fn into_inner(mut self) -> Self::Inner {
322        let socket = self.accepted_fd.take().expect("socket not accepted");
323        (socket, unsafe { SockAddr::new(self.buffer, self.addr_len) })
324    }
325}
326
327#[doc(hidden)]
328pub struct RecvVectoredControl {
329    pub(crate) msg: libc::msghdr,
330    #[allow(dead_code)]
331    pub(crate) slices: Vec<SysSlice>,
332}
333
334impl Default for RecvVectoredControl {
335    fn default() -> Self {
336        Self {
337            msg: unsafe { std::mem::zeroed() },
338            slices: Vec::new(),
339        }
340    }
341}
342
343impl<T: IoVectoredBufMut, S> RecvVectored<T, S> {
344    pub(crate) fn init_control(&mut self, ctrl: &mut RecvVectoredControl) {
345        ctrl.slices = self.buffer.sys_slices_mut();
346        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
347        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
348    }
349}
350
351#[doc(hidden)]
352pub struct SendVectoredControl {
353    pub(crate) msg: libc::msghdr,
354    #[allow(dead_code)]
355    pub(crate) slices: Vec<SysSlice>,
356}
357
358impl Default for SendVectoredControl {
359    fn default() -> Self {
360        Self {
361            msg: unsafe { std::mem::zeroed() },
362            slices: Vec::new(),
363        }
364    }
365}
366
367impl<T: IoVectoredBuf, S> SendVectored<T, S> {
368    pub(crate) fn init_control(&mut self, ctrl: &mut SendVectoredControl) {
369        ctrl.slices = self.buffer.sys_slices();
370        ctrl.msg.msg_iov = ctrl.slices.as_ptr() as _;
371        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
372    }
373}
374
375#[doc(hidden)]
376pub struct SendMsgControl {
377    pub(crate) msg: libc::msghdr,
378    #[allow(dead_code)]
379    pub(crate) slices: Multi<SysSlice>,
380}
381
382impl<S: AsFd> SendToHeader<S> {
383    #[allow(dead_code)]
384    pub(crate) fn create_control(
385        &mut self,
386        ctrl: &mut SendMsgControl,
387        slices: impl Into<Multi<SysSlice>>,
388    ) {
389        ctrl.msg.msg_name = self.addr.as_ptr() as _;
390        ctrl.msg.msg_namelen = self.addr.len();
391        ctrl.slices = slices.into();
392        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
393        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
394    }
395}
396
397impl Default for SendMsgControl {
398    fn default() -> Self {
399        Self {
400            msg: unsafe { std::mem::zeroed() },
401            slices: Multi::new(),
402        }
403    }
404}
405
406impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
407    pub(crate) fn init_control(&mut self, ctrl: &mut SendMsgControl) {
408        ctrl.slices = self.buffer.sys_slices().into();
409        match self.addr.as_ref() {
410            Some(addr) => {
411                ctrl.msg.msg_name = addr.as_ptr() as _;
412                ctrl.msg.msg_namelen = addr.len();
413            }
414            None => {
415                ctrl.msg.msg_name = std::ptr::null_mut();
416                ctrl.msg.msg_namelen = 0;
417            }
418        }
419        ctrl.msg.msg_iov = ctrl.slices.as_ptr() as _;
420        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
421        ctrl.msg.msg_control = self.control.buf_ptr() as _;
422        ctrl.msg.msg_controllen = self.control.buf_len() as _;
423    }
424}
425
426#[doc(hidden)]
427pub struct RecvMsgControl {
428    pub(crate) msg: libc::msghdr,
429    #[allow(dead_code)]
430    pub(crate) slices: Multi<SysSlice>,
431}
432
433impl Default for RecvMsgControl {
434    fn default() -> Self {
435        Self {
436            msg: unsafe { std::mem::zeroed() },
437            slices: Multi::new(),
438        }
439    }
440}
441
442impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
443    pub(crate) fn init_control(&mut self, ctrl: &mut RecvMsgControl) {
444        ctrl.slices = Multi::from_vec(self.buffer.sys_slices_mut());
445        ctrl.msg.msg_name = &raw mut self.header.addr as _;
446        ctrl.msg.msg_namelen = self.header.addr.size_of() as _;
447        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
448        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
449        ctrl.msg.msg_control = self.control.buf_mut_ptr() as _;
450        ctrl.msg.msg_controllen = self.control.buf_capacity() as _;
451    }
452
453    pub(crate) fn update_control(&mut self, control: &RecvMsgControl) {
454        self.header.addr_len = control.msg.msg_namelen as _;
455        self.control_len = control.msg.msg_controllen as _;
456    }
457}