fluke_buffet/net/
net_uring.rs

1use std::{
2    mem::ManuallyDrop,
3    net::SocketAddr,
4    os::fd::{AsRawFd, FromRawFd, RawFd},
5    rc::Rc,
6};
7
8use io_uring::opcode::{Accept, Read, Write};
9use nix::errno::Errno;
10
11use crate::{
12    get_ring,
13    io::{IntoHalves, ReadOwned, WriteOwned},
14    BufResult, IoBufMut, Piece,
15};
16
17pub struct TcpStream {
18    fd: i32,
19}
20
21impl TcpStream {
22    // TODO: nodelay
23    pub async fn connect(addr: SocketAddr) -> std::io::Result<Self> {
24        let addr: socket2::SockAddr = addr.into();
25        let socket = ManuallyDrop::new(socket2::Socket::new(
26            addr.domain(),
27            socket2::Type::STREAM,
28            None,
29        )?);
30        let fd = socket.as_raw_fd();
31
32        let u = get_ring();
33
34        let addr = Box::into_raw(Box::new(addr));
35        let sqe = unsafe {
36            io_uring::opcode::Connect::new(io_uring::types::Fd(fd), addr as *const _, (*addr).len())
37        }
38        .build();
39        let cqe = u.push(sqe).await;
40        cqe.error_for_errno()?;
41        Ok(Self { fd })
42    }
43}
44
45impl Drop for TcpStream {
46    fn drop(&mut self) {
47        // TODO: rethink this.
48        // what about all the in-flight operations?
49        unsafe {
50            libc::close(self.fd);
51        }
52    }
53}
54
55pub struct TcpListener {
56    fd: i32,
57}
58
59impl TcpListener {
60    // note: this is only async to match tokio's API
61    // TODO: investigate why tokio's TcpListener::bind is async
62    pub async fn bind(addr: SocketAddr) -> std::io::Result<Self> {
63        let addr: socket2::SockAddr = addr.into();
64        let socket = socket2::Socket::new(addr.domain(), socket2::Type::STREAM, None)?;
65        socket.bind(&addr)?;
66        // FIXME: magic values
67        socket.listen(16)?;
68        let fd = socket.as_raw_fd();
69        std::mem::forget(socket);
70
71        Ok(Self { fd })
72    }
73
74    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
75        let socket = ManuallyDrop::new(unsafe { socket2::Socket::from_raw_fd(self.fd) });
76        let addr = socket.local_addr()?;
77        Ok(addr.as_socket().unwrap())
78    }
79
80    pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
81        let u = get_ring();
82        struct AcceptUserData {
83            sockaddr_storage: libc::sockaddr_storage,
84            sockaddr_len: libc::socklen_t,
85        }
86        // FIXME: this currently leaks if the future is dropped
87        let udata = Box::into_raw(Box::new(AcceptUserData {
88            sockaddr_storage: unsafe { std::mem::zeroed() },
89            sockaddr_len: std::mem::size_of::<libc::sockaddr>() as libc::socklen_t,
90        }));
91
92        let sqe = unsafe {
93            Accept::new(
94                io_uring::types::Fd(self.fd),
95                &mut (*udata).sockaddr_storage as *mut _ as *mut _,
96                &mut (*udata).sockaddr_len,
97            )
98            .build()
99        };
100        let cqe = u.push(sqe).await;
101        let fd = cqe.error_for_errno()?;
102
103        let udata = unsafe { Box::from_raw(udata) };
104        let addr = unsafe { socket2::SockAddr::new(udata.sockaddr_storage, udata.sockaddr_len) };
105        let peer_addr = addr.as_socket().unwrap();
106
107        Ok((TcpStream { fd }, peer_addr))
108    }
109}
110
111// TODO: fix about the lifetime of TcpStream, closing
112// the underlying fd, in-flight operations etc.
113pub struct TcpReadHalf(Rc<TcpStream>);
114
115impl ReadOwned for TcpReadHalf {
116    async fn read_owned<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
117        let sqe = Read::new(
118            io_uring::types::Fd(self.0.fd),
119            buf.io_buf_mut_stable_mut_ptr(),
120            buf.io_buf_mut_capacity() as u32,
121        )
122        .build();
123        let cqe = get_ring().push(sqe).await;
124        let ret = match cqe.error_for_errno() {
125            Ok(ret) => ret,
126            Err(e) => return (Err(std::io::Error::from(e)), buf),
127        };
128        (Ok(ret as usize), buf)
129    }
130}
131
132pub struct TcpWriteHalf(Rc<TcpStream>);
133
134impl WriteOwned for TcpWriteHalf {
135    async fn write_owned(&mut self, buf: impl Into<Piece>) -> BufResult<usize, Piece> {
136        let buf = buf.into();
137        let sqe = Write::new(
138            io_uring::types::Fd(self.0.fd),
139            buf.as_ref().as_ptr(),
140            buf.len().try_into().expect("usize -> u32"),
141        )
142        .build();
143        let cqe = get_ring().push(sqe).await;
144        let ret = match cqe.error_for_errno() {
145            Ok(ret) => ret,
146            Err(e) => return (Err(std::io::Error::from(e)), buf),
147        };
148        (Ok(ret as usize), buf)
149    }
150
151    // TODO: implement writev
152
153    async fn shutdown(&mut self) -> std::io::Result<()> {
154        let sqe =
155            io_uring::opcode::Shutdown::new(io_uring::types::Fd(self.0.fd), libc::SHUT_WR).build();
156        let cqe = get_ring().push(sqe).await;
157        cqe.error_for_errno()?;
158        Ok(())
159    }
160}
161
162impl IntoHalves for TcpStream {
163    type Read = TcpReadHalf;
164    type Write = TcpWriteHalf;
165
166    fn into_halves(self) -> (Self::Read, Self::Write) {
167        let self_rc = Rc::new(self);
168        (TcpReadHalf(self_rc.clone()), TcpWriteHalf(self_rc))
169    }
170}
171
172impl FromRawFd for TcpStream {
173    unsafe fn from_raw_fd(_fd: RawFd) -> Self {
174        todo!()
175    }
176}
177
178trait CqueueExt {
179    fn error_for_errno(&self) -> Result<i32, Errno>;
180}
181
182impl CqueueExt for io_uring::cqueue::Entry {
183    fn error_for_errno(&self) -> Result<i32, Errno> {
184        let res = self.result();
185        if res < 0 {
186            Err(Errno::from_raw(-res))
187        } else {
188            Ok(res as _)
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use crate::io::{IntoHalves, ReadOwned, WriteOwned};

    /// End-to-end check of `TcpListener::accept` and both halves of the accepted
    /// `TcpStream`, talking to a blocking std client on another thread.
    #[test]
    fn test_accept() {
        // Installing fails if another test already installed the hook; that is harmless.
        let _ = color_eyre::install();

        async fn test_accept_inner() -> color_eyre::Result<()> {
            let listener = super::TcpListener::bind("127.0.0.1:0".parse().unwrap()).await?;
            let addr = listener.local_addr()?;
            println!("listening on {}", addr);

            // Joined below so that a panic in the client actually fails the test.
            let client = std::thread::spawn(move || {
                use std::io::{Read, Write};

                let mut sock = std::net::TcpStream::connect(addr).unwrap();
                println!(
                    "[client] connected! local={:?}, remote={:?}",
                    sock.local_addr(),
                    sock.peer_addr()
                );

                let mut buf = [0u8; 5];
                sock.read_exact(&mut buf).unwrap();
                println!("[client] read: {:?}", std::str::from_utf8(&buf).unwrap());
                assert_eq!(&buf, b"howdy");

                sock.write_all(b"hello").unwrap();
                println!("[client] wrote: hello");
            });

            let (stream, addr) = listener.accept().await?;
            println!("accepted connection!, addr={addr:?}");

            let (mut r, mut w) = stream.into_halves();
            // write howdy
            w.write_all_owned("howdy").await?;

            // TCP may split the client's 5 bytes across several reads. Keep reading until
            // all of them have arrived or the peer closes the connection.
            let mut received = Vec::new();
            let mut buf = vec![0u8; 1024];
            while received.len() < 5 {
                let (res, returned) = r.read_owned(buf).await;
                buf = returned;
                let n = res?;
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&buf[..n]);
            }
            println!(
                "read {} bytes: {:?}, as string: {:?}",
                received.len(),
                received,
                std::str::from_utf8(&received)?
            );

            client.join().expect("client thread panicked");
            assert_eq!(received, b"hello");

            Ok(())
        }
        crate::start(async move { test_accept_inner().await.unwrap() });
    }
}