Skip to main content

compio_driver/sys/op/socket/
unix.rs

1use std::{net::Shutdown, num::NonZeroU32};
2
3use rustix::{
4    io::close,
5    net::{
6        AddressFamily, Protocol, RecvAncillaryBuffer, SendAncillaryBuffer, SocketAddrAny,
7        SocketType, acceptfrom_with, bind, connect, listen, recv, recvfrom, recvmsg, send, sendmsg,
8        sendmsg_addr, sendto, shutdown, socket_with,
9    },
10};
11
12use crate::{PollFirst, sys::op::*};
13
14impl<S: AsFd> Accept<S> {
15    pub(crate) fn call(&mut self) -> io::Result<usize> {
16        let (owned, addr) = acceptfrom_with(self.fd.as_fd(), SOCKET_FLAG)?;
17        let fd = owned.as_raw_fd();
18        let socket: Socket2 = owned.into();
19
20        if cfg!(apple) {
21            socket.set_cloexec(true)?;
22            socket.set_nonblocking(true)?;
23        }
24
25        copy_addr_from(&mut self.buffer, &mut self.addr_len, addr);
26        self.accepted_fd = Some(socket);
27
28        Ok(fd as usize)
29    }
30}
31
32impl<S: AsFd> Connect<S> {
33    pub(crate) fn call(&self) -> io::Result<usize> {
34        connect(&self.fd, &SockAddrArg(&self.addr))?;
35        Ok(0)
36    }
37}
38
39impl<T: IoBuf, S: AsFd> Send<T, S> {
40    pub(crate) fn call(&mut self) -> io::Result<usize> {
41        send(self.fd.as_fd(), self.buffer.as_init(), self.flags).map_err(Into::into)
42    }
43}
44
45impl<T: IoBuf, S: AsFd> SendTo<T, S> {
46    pub(crate) fn call(&self) -> io::Result<usize> {
47        sendto(
48            self.header.fd.as_fd(),
49            self.buffer.as_init(),
50            self.header.flags,
51            &SockAddrArg(&self.header.addr),
52        )
53        .map_err(Into::into)
54    }
55}
56
57impl<T: IoVectoredBuf, S: AsFd> SendVectored<T, S> {
58    pub(crate) fn call(&self, control: &mut SendVectoredControl) -> io::Result<usize> {
59        let mut anc = SendAncillaryBuffer::default();
60
61        sendmsg(
62            self.fd.as_fd(),
63            io_slice(&control.slices),
64            &mut anc,
65            self.flags,
66        )
67        .map_err(Into::into)
68    }
69}
70
71impl<T: IoVectoredBuf, S: AsFd> SendToVectored<T, S> {
72    pub(crate) fn call(&mut self, control: &mut SendMsgControl) -> io::Result<usize> {
73        let addr = SockAddrArg(&self.header.addr);
74        let mut anc = SendAncillaryBuffer::default();
75        let buf = io_slice(&control.slices);
76
77        sendmsg_addr(
78            self.header.fd.as_fd(),
79            &addr,
80            buf,
81            &mut anc,
82            self.header.flags,
83        )
84        .map_err(Into::into)
85    }
86}
87
88impl<T: IoVectoredBuf, C: IoBuf, S: AsFd> SendMsg<T, C, S> {
89    pub(crate) fn call(&mut self, control: &mut SendMsgControl) -> io::Result<usize> {
90        // Both rustix and nix expose api that uses structured AncillaryBuffer
91        // building, no way to just throw in an ancillary buf. Fallback to libc here.
92        syscall!(libc::sendmsg(
93            self.fd.as_fd().as_raw_fd(),
94            &control.msg,
95            self.flags.bits() as _,
96        ))
97    }
98}
99
100impl<T: IoBufMut, S: AsFd> Recv<T, S> {
101    pub(crate) fn call(&mut self) -> io::Result<usize> {
102        let (_, len) = recv(self.fd.as_fd(), self.buffer.as_uninit(), self.flags)?;
103
104        Ok(len)
105    }
106}
107
108impl<T: IoVectoredBufMut, S: AsFd> RecvVectored<T, S> {
109    pub(crate) fn call(&mut self, control: &mut RecvVectoredControl) -> io::Result<usize> {
110        let res = recvmsg(
111            self.fd.as_fd(),
112            io_slice_mut(&mut control.slices),
113            &mut RecvAncillaryBuffer::default(),
114            self.flags,
115        )?;
116
117        // Kernel may truncate and return a larger-than-buffer size
118        Ok(res.bytes.min(self.buffer.total_capacity()))
119    }
120}
121
122impl<S: AsFd> RecvFromHeader<S> {
123    pub fn set_addr(&mut self, addr: Option<SocketAddrAny>) {
124        copy_addr_from(&mut self.addr, &mut self.addr_len, addr)
125    }
126}
127
128impl<T: IoBufMut, S: AsFd> RecvFrom<T, S> {
129    pub(crate) fn call(&mut self) -> io::Result<usize> {
130        let (_, len, addr) = recvfrom(&self.header.fd, self.buffer.as_uninit(), self.header.flags)?;
131
132        self.header.set_addr(addr);
133
134        Ok(len.min(self.buffer.buf_capacity()))
135    }
136}
137
138impl<T: IoVectoredBufMut, C: IoBufMut, S: AsFd> RecvMsg<T, C, S> {
139    pub(crate) fn call(&mut self, control: &mut RecvMsgControl) -> io::Result<usize> {
140        let res = syscall!(libc::recvmsg(
141            self.header.fd.as_fd().as_raw_fd(),
142            &raw mut control.msg,
143            self.header.flags.bits() as _,
144        ))?;
145
146        self.update_control(control);
147
148        Ok(res)
149    }
150}
151
152impl<T: IoVectoredBufMut, S: AsFd> RecvFromVectored<T, S> {
153    pub(crate) fn call(&mut self, control: &mut RecvMsgControl) -> io::Result<usize> {
154        let res = recvmsg(
155            &self.header.fd,
156            io_slice_mut(&mut control.slices),
157            &mut RecvAncillaryBuffer::default(),
158            self.header.flags,
159        )?;
160
161        self.header.set_addr(res.address);
162
163        Ok(res.bytes)
164    }
165}
166
167/// Create a socket.
168pub struct CreateSocket {
169    pub(crate) domain: AddressFamily,
170    pub(crate) socket_type: SocketType,
171    pub(crate) protocol: Option<Protocol>,
172    pub(crate) opened_fd: Option<Socket2>,
173}
174
175impl CreateSocket {
176    /// Create [`CreateSocket`].
177    pub fn new(domain: i32, socket_type: i32, protocol: i32) -> Self {
178        let domain = AddressFamily::from_raw(domain as _);
179        let socket_type = SocketType::from_raw(socket_type as _);
180        let protocol = NonZeroU32::new(protocol as _).map(Protocol::from_raw);
181
182        Self {
183            domain,
184            socket_type,
185            protocol,
186            opened_fd: None,
187        }
188    }
189
190    pub(crate) fn call(&mut self) -> io::Result<usize> {
191        let owned = socket_with(self.domain, self.socket_type, SOCKET_FLAG, self.protocol)?;
192        let fd = owned.as_raw_fd();
193        let socket: Socket2 = owned.into();
194
195        #[cfg(apple)]
196        {
197            socket.set_cloexec(true)?;
198            socket.set_nosigpipe(true)?;
199            socket.set_nonblocking(true)?;
200        }
201
202        self.opened_fd = Some(socket);
203        Ok(fd as _)
204    }
205}
206
207impl IntoInner for CreateSocket {
208    type Inner = Socket2;
209
210    fn into_inner(self) -> Self::Inner {
211        self.opened_fd.expect("socket not created")
212    }
213}
214
215/// Bind a socket to an address.
216pub struct Bind<S> {
217    pub(crate) fd: S,
218    pub(crate) addr: SockAddr,
219}
220
221impl<S> Bind<S> {
222    /// Create [`Bind`].
223    pub fn new(fd: S, addr: SockAddr) -> Self {
224        Self { fd, addr }
225    }
226}
227
228impl<S: AsFd> Bind<S> {
229    pub(crate) fn call(&self) -> io::Result<usize> {
230        bind(self.fd.as_fd(), &SockAddrArg(&self.addr))?;
231        Ok(0)
232    }
233}
234
235/// Listen for connections on a socket.
236pub struct Listen<S> {
237    pub(crate) fd: S,
238    pub(crate) backlog: i32,
239}
240
241impl<S> Listen<S> {
242    /// Create [`Listen`].
243    pub fn new(fd: S, backlog: i32) -> Self {
244        Self { fd, backlog }
245    }
246}
247
248impl<S: AsFd> Listen<S> {
249    pub(crate) fn call(&self) -> io::Result<usize> {
250        listen(self.fd.as_fd(), self.backlog)?;
251        Ok(0)
252    }
253}
254
255/// Shutdown a socket.
256pub struct ShutdownSocket<S> {
257    pub(crate) fd: S,
258    pub(crate) how: Shutdown,
259}
260
261impl<S> ShutdownSocket<S> {
262    /// Create [`ShutdownSocket`].
263    pub fn new(fd: S, how: Shutdown) -> Self {
264        Self { fd, how }
265    }
266}
267
268impl<S: AsFd> ShutdownSocket<S> {
269    #[cfg(io_uring)]
270    pub(crate) fn how(&self) -> i32 {
271        match self.how {
272            Shutdown::Write => libc::SHUT_WR,
273            Shutdown::Read => libc::SHUT_RD,
274            Shutdown::Both => libc::SHUT_RDWR,
275        }
276    }
277
278    pub(crate) fn call(&mut self) -> io::Result<usize> {
279        let how = match self.how {
280            Shutdown::Write => rustix::net::Shutdown::Write,
281            Shutdown::Read => rustix::net::Shutdown::Read,
282            Shutdown::Both => rustix::net::Shutdown::Both,
283        };
284        shutdown(&self.fd, how)?;
285        Ok(0)
286    }
287}
288
289impl CloseSocket {
290    pub(crate) fn call(&mut self) -> io::Result<usize> {
291        unsafe { close(self.fd.as_raw_fd()) };
292        Ok(0)
293    }
294}
295
296/// Accept a connection.
297pub struct Accept<S> {
298    pub(crate) fd: S,
299    pub(crate) buffer: SockAddrStorage,
300    pub(crate) addr_len: socklen_t,
301    pub(crate) accepted_fd: Option<Socket2>,
302    pub(crate) poll_first: bool,
303}
304
305impl<S> Accept<S> {
306    /// Create [`Accept`].
307    pub fn new(fd: S) -> Self {
308        let buffer = SockAddrStorage::zeroed();
309        let addr_len = buffer.size_of();
310        Self {
311            fd,
312            buffer,
313            addr_len,
314            accepted_fd: None,
315            poll_first: false,
316        }
317    }
318}
319
320impl<S> PollFirst for Accept<S> {
321    fn poll_first(&mut self) {
322        self.poll_first = true;
323    }
324}
325
326impl<S> IntoInner for Accept<S> {
327    type Inner = (Socket2, SockAddr);
328
329    fn into_inner(mut self) -> Self::Inner {
330        let socket = self.accepted_fd.take().expect("socket not accepted");
331        (socket, unsafe { SockAddr::new(self.buffer, self.addr_len) })
332    }
333}
334
335#[doc(hidden)]
336pub struct RecvVectoredControl {
337    pub(crate) msg: libc::msghdr,
338    #[allow(dead_code)]
339    pub(crate) slices: Vec<SysSlice>,
340}
341
342impl Default for RecvVectoredControl {
343    fn default() -> Self {
344        Self {
345            msg: unsafe { std::mem::zeroed() },
346            slices: Vec::new(),
347        }
348    }
349}
350
351impl<T: IoVectoredBufMut, S> RecvVectored<T, S> {
352    pub(crate) fn init_control(&mut self, ctrl: &mut RecvVectoredControl) {
353        ctrl.slices = self.buffer.sys_slices_mut();
354        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
355        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
356    }
357}
358
359#[doc(hidden)]
360pub struct SendVectoredControl {
361    pub(crate) msg: libc::msghdr,
362    #[allow(dead_code)]
363    pub(crate) slices: Vec<SysSlice>,
364}
365
366impl Default for SendVectoredControl {
367    fn default() -> Self {
368        Self {
369            msg: unsafe { std::mem::zeroed() },
370            slices: Vec::new(),
371        }
372    }
373}
374
375impl<T: IoVectoredBuf, S> SendVectored<T, S> {
376    pub(crate) fn init_control(&mut self, ctrl: &mut SendVectoredControl) {
377        ctrl.slices = self.buffer.sys_slices();
378        ctrl.msg.msg_iov = ctrl.slices.as_ptr() as _;
379        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
380    }
381}
382
383#[doc(hidden)]
384pub struct SendMsgControl {
385    pub(crate) msg: libc::msghdr,
386    #[allow(dead_code)]
387    pub(crate) slices: Multi<SysSlice>,
388}
389
390impl<S: AsFd> SendToHeader<S> {
391    #[allow(dead_code)]
392    pub(crate) fn create_control(
393        &mut self,
394        ctrl: &mut SendMsgControl,
395        slices: impl Into<Multi<SysSlice>>,
396    ) {
397        ctrl.msg.msg_name = self.addr.as_ptr() as _;
398        ctrl.msg.msg_namelen = self.addr.len();
399        ctrl.slices = slices.into();
400        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
401        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
402    }
403}
404
405impl Default for SendMsgControl {
406    fn default() -> Self {
407        Self {
408            msg: unsafe { std::mem::zeroed() },
409            slices: Multi::new(),
410        }
411    }
412}
413
414impl<T: IoVectoredBuf, C: IoBuf, S> SendMsg<T, C, S> {
415    pub(crate) fn init_control(&mut self, ctrl: &mut SendMsgControl) {
416        ctrl.slices = self.buffer.sys_slices().into();
417        match self.addr.as_ref() {
418            Some(addr) => {
419                ctrl.msg.msg_name = addr.as_ptr() as _;
420                ctrl.msg.msg_namelen = addr.len();
421            }
422            None => {
423                ctrl.msg.msg_name = std::ptr::null_mut();
424                ctrl.msg.msg_namelen = 0;
425            }
426        }
427        ctrl.msg.msg_iov = ctrl.slices.as_ptr() as _;
428        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
429        ctrl.msg.msg_control = self.control.buf_ptr() as _;
430        ctrl.msg.msg_controllen = self.control.buf_len() as _;
431    }
432}
433
434#[doc(hidden)]
435pub struct RecvMsgControl {
436    pub(crate) msg: libc::msghdr,
437    #[allow(dead_code)]
438    pub(crate) slices: Multi<SysSlice>,
439}
440
441impl Default for RecvMsgControl {
442    fn default() -> Self {
443        Self {
444            msg: unsafe { std::mem::zeroed() },
445            slices: Multi::new(),
446        }
447    }
448}
449
450impl<T: IoVectoredBufMut, C: IoBufMut, S> RecvMsg<T, C, S> {
451    pub(crate) fn init_control(&mut self, ctrl: &mut RecvMsgControl) {
452        ctrl.slices = Multi::from_vec(self.buffer.sys_slices_mut());
453        ctrl.msg.msg_name = &raw mut self.header.addr as _;
454        ctrl.msg.msg_namelen = self.header.addr.size_of() as _;
455        ctrl.msg.msg_iov = ctrl.slices.as_mut_ptr() as _;
456        ctrl.msg.msg_iovlen = ctrl.slices.len() as _;
457        ctrl.msg.msg_control = self.control.buf_mut_ptr() as _;
458        ctrl.msg.msg_controllen = self.control.buf_capacity() as _;
459    }
460
461    pub(crate) fn update_control(&mut self, control: &RecvMsgControl) {
462        self.header.addr_len = control.msg.msg_namelen as _;
463        self.control_len = control.msg.msg_controllen as _;
464        self.return_flags = ReturnFlags::from_bits_retain(control.msg.msg_flags as _);
465    }
466}