fluke_buffet/net/
net_uring.rs1use std::{
2 mem::ManuallyDrop,
3 net::SocketAddr,
4 os::fd::{AsRawFd, FromRawFd, RawFd},
5 rc::Rc,
6};
7
8use io_uring::opcode::{Accept, Read, Write};
9use nix::errno::Errno;
10
11use crate::{
12 get_ring,
13 io::{IntoHalves, ReadOwned, WriteOwned},
14 BufResult, IoBufMut, Piece,
15};
16
/// A TCP stream backed by a raw socket file descriptor that it owns.
///
/// The descriptor is closed by the `Drop` implementation for this type.
pub struct TcpStream {
    /// Socket descriptor owned by this stream.
    fd: RawFd,
}
20
21impl TcpStream {
22 pub async fn connect(addr: SocketAddr) -> std::io::Result<Self> {
24 let addr: socket2::SockAddr = addr.into();
25 let socket = ManuallyDrop::new(socket2::Socket::new(
26 addr.domain(),
27 socket2::Type::STREAM,
28 None,
29 )?);
30 let fd = socket.as_raw_fd();
31
32 let u = get_ring();
33
34 let addr = Box::into_raw(Box::new(addr));
35 let sqe = unsafe {
36 io_uring::opcode::Connect::new(io_uring::types::Fd(fd), addr as *const _, (*addr).len())
37 }
38 .build();
39 let cqe = u.push(sqe).await;
40 cqe.error_for_errno()?;
41 Ok(Self { fd })
42 }
43}
44
45impl Drop for TcpStream {
46 fn drop(&mut self) {
47 unsafe {
50 libc::close(self.fd);
51 }
52 }
53}
54
/// A listening TCP socket identified by a raw file descriptor.
pub struct TcpListener {
    /// Descriptor of the bound, listening socket.
    fd: RawFd,
}
58
59impl TcpListener {
60 pub async fn bind(addr: SocketAddr) -> std::io::Result<Self> {
63 let addr: socket2::SockAddr = addr.into();
64 let socket = socket2::Socket::new(addr.domain(), socket2::Type::STREAM, None)?;
65 socket.bind(&addr)?;
66 socket.listen(16)?;
68 let fd = socket.as_raw_fd();
69 std::mem::forget(socket);
70
71 Ok(Self { fd })
72 }
73
74 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
75 let socket = ManuallyDrop::new(unsafe { socket2::Socket::from_raw_fd(self.fd) });
76 let addr = socket.local_addr()?;
77 Ok(addr.as_socket().unwrap())
78 }
79
80 pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
81 let u = get_ring();
82 struct AcceptUserData {
83 sockaddr_storage: libc::sockaddr_storage,
84 sockaddr_len: libc::socklen_t,
85 }
86 let udata = Box::into_raw(Box::new(AcceptUserData {
88 sockaddr_storage: unsafe { std::mem::zeroed() },
89 sockaddr_len: std::mem::size_of::<libc::sockaddr>() as libc::socklen_t,
90 }));
91
92 let sqe = unsafe {
93 Accept::new(
94 io_uring::types::Fd(self.fd),
95 &mut (*udata).sockaddr_storage as *mut _ as *mut _,
96 &mut (*udata).sockaddr_len,
97 )
98 .build()
99 };
100 let cqe = u.push(sqe).await;
101 let fd = cqe.error_for_errno()?;
102
103 let udata = unsafe { Box::from_raw(udata) };
104 let addr = unsafe { socket2::SockAddr::new(udata.sockaddr_storage, udata.sockaddr_len) };
105 let peer_addr = addr.as_socket().unwrap();
106
107 Ok((TcpStream { fd }, peer_addr))
108 }
109}
110
111pub struct TcpReadHalf(Rc<TcpStream>);
114
115impl ReadOwned for TcpReadHalf {
116 async fn read_owned<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
117 let sqe = Read::new(
118 io_uring::types::Fd(self.0.fd),
119 buf.io_buf_mut_stable_mut_ptr(),
120 buf.io_buf_mut_capacity() as u32,
121 )
122 .build();
123 let cqe = get_ring().push(sqe).await;
124 let ret = match cqe.error_for_errno() {
125 Ok(ret) => ret,
126 Err(e) => return (Err(std::io::Error::from(e)), buf),
127 };
128 (Ok(ret as usize), buf)
129 }
130}
131
132pub struct TcpWriteHalf(Rc<TcpStream>);
133
134impl WriteOwned for TcpWriteHalf {
135 async fn write_owned(&mut self, buf: impl Into<Piece>) -> BufResult<usize, Piece> {
136 let buf = buf.into();
137 let sqe = Write::new(
138 io_uring::types::Fd(self.0.fd),
139 buf.as_ref().as_ptr(),
140 buf.len().try_into().expect("usize -> u32"),
141 )
142 .build();
143 let cqe = get_ring().push(sqe).await;
144 let ret = match cqe.error_for_errno() {
145 Ok(ret) => ret,
146 Err(e) => return (Err(std::io::Error::from(e)), buf),
147 };
148 (Ok(ret as usize), buf)
149 }
150
151 async fn shutdown(&mut self) -> std::io::Result<()> {
154 let sqe =
155 io_uring::opcode::Shutdown::new(io_uring::types::Fd(self.0.fd), libc::SHUT_WR).build();
156 let cqe = get_ring().push(sqe).await;
157 cqe.error_for_errno()?;
158 Ok(())
159 }
160}
161
162impl IntoHalves for TcpStream {
163 type Read = TcpReadHalf;
164 type Write = TcpWriteHalf;
165
166 fn into_halves(self) -> (Self::Read, Self::Write) {
167 let self_rc = Rc::new(self);
168 (TcpReadHalf(self_rc.clone()), TcpWriteHalf(self_rc))
169 }
170}
171
172impl FromRawFd for TcpStream {
173 unsafe fn from_raw_fd(_fd: RawFd) -> Self {
174 todo!()
175 }
176}
177
178trait CqueueExt {
179 fn error_for_errno(&self) -> Result<i32, Errno>;
180}
181
182impl CqueueExt for io_uring::cqueue::Entry {
183 fn error_for_errno(&self) -> Result<i32, Errno> {
184 let res = self.result();
185 if res < 0 {
186 Err(Errno::from_raw(-res))
187 } else {
188 Ok(res as _)
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use crate::io::{IntoHalves, ReadOwned, WriteOwned};

    /// Round-trips a message in each direction between an io_uring listener and a
    /// blocking std client, asserting both payloads and the peer address.
    #[test]
    fn test_accept() {
        color_eyre::install().unwrap();

        async fn test_accept_inner() -> color_eyre::Result<()> {
            let listener = super::TcpListener::bind("127.0.0.1:0".parse().unwrap()).await?;
            let addr = listener.local_addr()?;
            println!("listening on {}", addr);

            // Keep the handle so assertion failures on the client thread fail the test
            // instead of being silently discarded.
            let client = std::thread::spawn(move || {
                use std::io::{Read, Write};

                let mut sock = std::net::TcpStream::connect(addr).unwrap();
                println!(
                    "[client] connected! local={:?}, remote={:?}",
                    sock.local_addr(),
                    sock.peer_addr()
                );

                let mut buf = [0u8; 5];
                sock.read_exact(&mut buf).unwrap();
                println!("[client] read: {:?}", std::str::from_utf8(&buf).unwrap());
                assert_eq!(&buf, b"howdy");

                sock.write_all(b"hello").unwrap();
                println!("[client] wrote: hello");
            });

            let (stream, addr) = listener.accept().await?;
            println!("accepted connection!, addr={addr:?}");
            assert!(addr.ip().is_loopback(), "unexpected peer address: {addr}");

            let (mut r, mut w) = stream.into_halves();
            w.write_all_owned("howdy").await?;

            // TCP does not preserve message boundaries: keep reading until the whole
            // payload has arrived or the peer closes the connection.
            let expected = b"hello";
            let mut received = Vec::new();
            let mut buf = vec![0u8; 1024];
            while received.len() < expected.len() {
                let (res, returned) = r.read_owned(buf).await;
                buf = returned;
                let n = res?;
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&buf[..n]);
            }
            println!(
                "read {} bytes: {:?}, as string: {:?}",
                received.len(),
                received,
                std::str::from_utf8(&received)?
            );
            assert_eq!(received, expected);

            // The client has already written its payload; joining only surfaces its panics.
            client.join().expect("client thread panicked");

            Ok(())
        }
        crate::start(async move { test_accept_inner().await.unwrap() });
    }
}