open_coroutine_iouring/
lib.rs

1#[cfg(target_os = "linux")]
2pub mod version;
3
4#[cfg(target_os = "linux")]
5pub mod io_uring;
6
7#[cfg(all(target_os = "linux", test))]
8mod tests {
9    use std::collections::VecDeque;
10    use std::io::{BufRead, BufReader, Write};
11    use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
12    use std::os::unix::io::{AsRawFd, RawFd};
13    use std::sync::atomic::{AtomicBool, Ordering};
14    use std::sync::Arc;
15    use std::time::Duration;
16    use std::{io, ptr};
17
18    use crate::io_uring::IoUringOperator;
19    use io_uring::{opcode, squeue, types, IoUring, SubmissionQueue};
20    use slab::Slab;
21
22    #[derive(Clone, Debug)]
23    enum Token {
24        Accept,
25        Poll {
26            fd: RawFd,
27        },
28        Read {
29            fd: RawFd,
30            buf_index: usize,
31        },
32        Write {
33            fd: RawFd,
34            buf_index: usize,
35            offset: usize,
36            len: usize,
37        },
38    }
39
40    pub struct AcceptCount {
41        entry: squeue::Entry,
42        count: usize,
43    }
44
45    impl AcceptCount {
46        fn new(fd: RawFd, token: usize, count: usize) -> AcceptCount {
47            AcceptCount {
48                entry: opcode::Accept::new(types::Fd(fd), ptr::null_mut(), ptr::null_mut())
49                    .build()
50                    .user_data(token as _),
51                count,
52            }
53        }
54
55        pub fn push_to(&mut self, sq: &mut SubmissionQueue<'_>) {
56            while self.count > 0 {
57                unsafe {
58                    match sq.push(&self.entry) {
59                        Ok(_) => self.count -= 1,
60                        Err(_) => break,
61                    }
62                }
63            }
64
65            sq.sync();
66        }
67    }
68
69    pub fn crate_server(port: u16, server_started: Arc<AtomicBool>) -> anyhow::Result<()> {
70        let mut ring: IoUring = IoUring::builder()
71            .setup_sqpoll(1000)
72            .setup_sqpoll_cpu(0)
73            .build(1024)?;
74        let listener = TcpListener::bind(("127.0.0.1", port))?;
75
76        let mut backlog = VecDeque::new();
77        let mut bufpool = Vec::with_capacity(64);
78        let mut buf_alloc = Slab::with_capacity(64);
79        let mut token_alloc = Slab::with_capacity(64);
80
81        println!("listen {}", listener.local_addr()?);
82        server_started.store(true, Ordering::Release);
83
84        let (submitter, mut sq, mut cq) = ring.split();
85
86        let mut accept =
87            AcceptCount::new(listener.as_raw_fd(), token_alloc.insert(Token::Accept), 1);
88
89        accept.push_to(&mut sq);
90
91        loop {
92            match submitter.submit_and_wait(1) {
93                Ok(_) => (),
94                Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => (),
95                Err(err) => return Err(err.into()),
96            }
97            cq.sync();
98
99            // clean backlog
100            loop {
101                if sq.is_full() {
102                    match submitter.submit() {
103                        Ok(_) => (),
104                        Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break,
105                        Err(err) => return Err(err.into()),
106                    }
107                }
108                sq.sync();
109
110                match backlog.pop_front() {
111                    Some(sqe) => unsafe {
112                        let _ = sq.push(&sqe);
113                    },
114                    None => break,
115                }
116            }
117
118            accept.push_to(&mut sq);
119
120            for cqe in &mut cq {
121                let ret = cqe.result();
122                let token_index = cqe.user_data() as usize;
123
124                if ret < 0 {
125                    eprintln!(
126                        "token {:?} error: {:?}",
127                        token_alloc.get(token_index),
128                        io::Error::from_raw_os_error(-ret)
129                    );
130                    continue;
131                }
132
133                let token = &mut token_alloc[token_index];
134                match token.clone() {
135                    Token::Accept => {
136                        println!("accept");
137
138                        accept.count += 1;
139
140                        let fd = ret;
141                        let poll_token = token_alloc.insert(Token::Poll { fd });
142
143                        let poll_e = opcode::PollAdd::new(types::Fd(fd), libc::POLLIN as _)
144                            .build()
145                            .user_data(poll_token as _);
146
147                        unsafe {
148                            if sq.push(&poll_e).is_err() {
149                                backlog.push_back(poll_e);
150                            }
151                        }
152                    }
153                    Token::Poll { fd } => {
154                        let (buf_index, buf) = match bufpool.pop() {
155                            Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]),
156                            None => {
157                                let buf = vec![0u8; 2048].into_boxed_slice();
158                                let buf_entry = buf_alloc.vacant_entry();
159                                let buf_index = buf_entry.key();
160                                (buf_index, buf_entry.insert(buf))
161                            }
162                        };
163
164                        *token = Token::Read { fd, buf_index };
165
166                        let read_e =
167                            opcode::Recv::new(types::Fd(fd), buf.as_mut_ptr(), buf.len() as _)
168                                .build()
169                                .user_data(token_index as _);
170
171                        unsafe {
172                            if sq.push(&read_e).is_err() {
173                                backlog.push_back(read_e);
174                            }
175                        }
176                    }
177                    Token::Read { fd, buf_index } => {
178                        if ret == 0 {
179                            bufpool.push(buf_index);
180                            token_alloc.remove(token_index);
181                            println!("shutdown connection");
182                            unsafe { libc::close(fd) };
183
184                            println!("Server closed");
185                            return Ok(());
186                        } else {
187                            let len = ret as usize;
188                            let buf = &buf_alloc[buf_index];
189
190                            *token = Token::Write {
191                                fd,
192                                buf_index,
193                                len,
194                                offset: 0,
195                            };
196
197                            let write_e = opcode::Send::new(types::Fd(fd), buf.as_ptr(), len as _)
198                                .build()
199                                .user_data(token_index as _);
200
201                            unsafe {
202                                if sq.push(&write_e).is_err() {
203                                    backlog.push_back(write_e);
204                                }
205                            }
206                        }
207                    }
208                    Token::Write {
209                        fd,
210                        buf_index,
211                        offset,
212                        len,
213                    } => {
214                        let write_len = ret as usize;
215
216                        let entry = if offset + write_len >= len {
217                            bufpool.push(buf_index);
218
219                            *token = Token::Poll { fd };
220
221                            opcode::PollAdd::new(types::Fd(fd), libc::POLLIN as _)
222                                .build()
223                                .user_data(token_index as _)
224                        } else {
225                            let offset = offset + write_len;
226                            let len = len - offset;
227
228                            let buf = &buf_alloc[buf_index][offset..];
229
230                            *token = Token::Write {
231                                fd,
232                                buf_index,
233                                offset,
234                                len,
235                            };
236
237                            opcode::Write::new(types::Fd(fd), buf.as_ptr(), len as _)
238                                .build()
239                                .user_data(token_index as _)
240                        };
241
242                        unsafe {
243                            if sq.push(&entry).is_err() {
244                                backlog.push_back(entry);
245                            }
246                        }
247                    }
248                }
249            }
250        }
251    }
252
253    pub fn crate_client(port: u16, server_started: Arc<AtomicBool>) {
254        //等服务端起来
255        while !server_started.load(Ordering::Acquire) {}
256        let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port);
257        let mut stream = TcpStream::connect_timeout(&socket, Duration::from_secs(3))
258            .unwrap_or_else(|_| panic!("connect to 127.0.0.1:3456 failed !"));
259        let mut data: [u8; 512] = [b'1'; 512];
260        data[511] = b'\n';
261        let mut buffer: Vec<u8> = Vec::with_capacity(512);
262        for _ in 0..3 {
263            //写入stream流,如果写入失败,提示"写入失败"
264            assert_eq!(512, stream.write(&data).expect("Failed to write!"));
265            print!("Client Send: {}", String::from_utf8_lossy(&data[..]));
266
267            let mut reader = BufReader::new(&stream);
268            //一直读到换行为止(b'\n'中的b表示字节),读到buffer里面
269            assert_eq!(
270                512,
271                reader
272                    .read_until(b'\n', &mut buffer)
273                    .expect("Failed to read into buffer")
274            );
275            print!("Client Received: {}", String::from_utf8_lossy(&buffer[..]));
276            assert_eq!(&data, &buffer as &[u8]);
277            buffer.clear();
278        }
279        //发送终止符
280        assert_eq!(1, stream.write(&[b'e']).expect("Failed to write!"));
281        println!("client closed");
282    }
283
284    #[test]
285    fn original() -> anyhow::Result<()> {
286        let port = 8488;
287        let server_started = Arc::new(AtomicBool::new(false));
288        let clone = server_started.clone();
289        let handle = std::thread::spawn(move || crate_server(port, clone));
290        std::thread::spawn(move || crate_client(port, server_started))
291            .join()
292            .expect("client has error");
293        handle.join().expect("server has error")
294    }
295
296    pub fn crate_server2(port: u16, server_started: Arc<AtomicBool>) -> anyhow::Result<()> {
297        let operator = IoUringOperator::new(0)?;
298        let listener = TcpListener::bind(("127.0.0.1", port))?;
299
300        let mut bufpool = Vec::with_capacity(64);
301        let mut buf_alloc = Slab::with_capacity(64);
302        let mut token_alloc = Slab::with_capacity(64);
303
304        println!("listen {}", listener.local_addr()?);
305        server_started.store(true, Ordering::Release);
306
307        operator.accept(
308            token_alloc.insert(Token::Accept),
309            listener.as_raw_fd(),
310            std::ptr::null_mut(),
311            std::ptr::null_mut(),
312        )?;
313
314        loop {
315            let mut r = operator.select(None)?;
316
317            for cqe in &mut r.1 {
318                let ret = cqe.result();
319                let token_index = cqe.user_data() as usize;
320
321                if ret < 0 {
322                    eprintln!(
323                        "token {:?} error: {:?}",
324                        token_alloc.get(token_index),
325                        io::Error::from_raw_os_error(-ret)
326                    );
327                    continue;
328                }
329
330                let token = &mut token_alloc[token_index];
331                match token.clone() {
332                    Token::Accept => {
333                        println!("accept");
334
335                        let fd = ret;
336                        let poll_token = token_alloc.insert(Token::Poll { fd });
337
338                        operator.poll_add(poll_token, fd, libc::POLLIN as _)?;
339                    }
340                    Token::Poll { fd } => {
341                        let (buf_index, buf) = match bufpool.pop() {
342                            Some(buf_index) => (buf_index, &mut buf_alloc[buf_index]),
343                            None => {
344                                let buf = vec![0u8; 2048].into_boxed_slice();
345                                let buf_entry = buf_alloc.vacant_entry();
346                                let buf_index = buf_entry.key();
347                                (buf_index, buf_entry.insert(buf))
348                            }
349                        };
350
351                        *token = Token::Read { fd, buf_index };
352
353                        operator.recv(token_index, fd, buf.as_mut_ptr() as _, buf.len(), 0)?;
354                    }
355                    Token::Read { fd, buf_index } => {
356                        if ret == 0 {
357                            bufpool.push(buf_index);
358                            token_alloc.remove(token_index);
359                            println!("shutdown connection");
360                            unsafe { libc::close(fd) };
361
362                            println!("Server closed");
363                            return Ok(());
364                        } else {
365                            let len = ret as usize;
366                            let buf = &buf_alloc[buf_index];
367
368                            *token = Token::Write {
369                                fd,
370                                buf_index,
371                                len,
372                                offset: 0,
373                            };
374
375                            operator.send(token_index, fd, buf.as_ptr() as _, len, 0)?;
376                        }
377                    }
378                    Token::Write {
379                        fd,
380                        buf_index,
381                        offset,
382                        len,
383                    } => {
384                        let write_len = ret as usize;
385
386                        if offset + write_len >= len {
387                            bufpool.push(buf_index);
388
389                            *token = Token::Poll { fd };
390
391                            operator.poll_add(token_index, fd, libc::POLLIN as _)?;
392                        } else {
393                            let offset = offset + write_len;
394                            let len = len - offset;
395
396                            let buf = &buf_alloc[buf_index][offset..];
397
398                            *token = Token::Write {
399                                fd,
400                                buf_index,
401                                offset,
402                                len,
403                            };
404
405                            operator.write(token_index, fd, buf.as_ptr() as _, len)?;
406                        };
407                    }
408                }
409            }
410        }
411    }
412
413    #[test]
414    fn framework() -> anyhow::Result<()> {
415        let port = 9898;
416        let server_started = Arc::new(AtomicBool::new(false));
417        let clone = server_started.clone();
418        let handle = std::thread::spawn(move || crate_server2(port, clone));
419        std::thread::spawn(move || crate_client(port, server_started))
420            .join()
421            .expect("client has error");
422        handle.join().expect("server has error")
423    }
424}