// Asynchronous TCP stream whose I/O is submitted through an io_uring-backed `Ring`.
use std::io;
use std::future::Future;
use std::net::ToSocketAddrs;
use std::os::unix::io::RawFd;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_core::ready;
use futures_io::{AsyncRead, AsyncBufRead, AsyncWrite};

use crate::buf::Buffer;
use crate::drive::demo::DemoDriver;
use crate::{Drive, Ring};
use crate::event;
use crate::Submission;

use super::socket;

/// A TCP stream whose reads, writes and close are submitted through a `Ring`.
///
/// A single internal `Buffer` is shared by the read and the write paths, and
/// `active` records which kind of operation most recently claimed the stream.
pub struct TcpStream<D: Drive = DemoDriver<'static>> {
    ring: Ring<D>,
    buf: Buffer,
    active: Op,
    fd: RawFd,
}

/// The kind of operation the stream was most recently asked to perform.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Op {
    Read,
    Write,
    Close,
    Nothing,
}

impl TcpStream {
    /// Starts connecting to `addr` using a default-constructed `DemoDriver`.
    ///
    /// The returned [`Connect`] future resolves to the connected stream.
    pub fn connect<A: ToSocketAddrs>(addr: A) -> Connect {
        Self::connect_on_driver(addr, DemoDriver::default())
    }
}

impl<D: Drive + Clone> TcpStream<D> {
    /// Starts connecting to `addr`, submitting the connect event on `driver`.
    ///
    /// If `socket(addr)` fails, the error is stored in the returned future and
    /// surfaced the first time that future is polled.
    pub fn connect_on_driver<A: ToSocketAddrs>(addr: A, driver: D) -> Connect<D> {
        match socket(addr) {
            Ok((fd, addr)) => {
                Connect(Ok(Submission::new(event::Connect::new(fd, addr), driver)))
            }
            Err(e) => Connect(Err(Some(e))),
        }
    }
}

impl<D: Drive> TcpStream<D> {
    /// Wraps an existing file descriptor and ring in a stream with a fresh
    /// buffer and no operation marked as active.
    pub(super) fn from_fd(fd: RawFd, ring: Ring<D>) -> TcpStream<D> {
        TcpStream {
            ring,
            buf: Buffer::new(),
            active: Op::Nothing,
            fd,
        }
    }

    /// Records `op` as the active operation, first cancelling whatever is in
    /// flight when the previously active operation was of a different kind.
    fn guard_op(self: Pin<&mut Self>, op: Op) {
        // SAFETY (reviewer note): nothing is moved out of the stream in this
        // function; it only overwrites `active` and calls `cancel`. This
        // presumes `Ring::cancel` does not move the ring — confirm.
        let stream = unsafe { self.get_unchecked_mut() };
        match stream.active {
            // Nothing was in progress, so there is nothing to cancel.
            Op::Nothing => {}
            // Same kind of operation as before: keep it going.
            current if current == op => {}
            // A different kind of operation was active: abandon it.
            _ => stream.cancel(),
        }
        stream.active = op;
    }

    /// Marks the stream idle and passes the buffer's cancellation to the ring.
    fn cancel(&mut self) {
        self.active = Op::Nothing;
        let cancellation = self.buf.cancellation();
        self.ring.cancel(cancellation);
    }

    /// Pin-projects to the `ring` field.
    #[inline(always)]
    fn ring(self: Pin<&mut Self>) -> Pin<&mut Ring<D>> {
        // SAFETY (reviewer note): plain field projection; `ring` is only ever
        // handed out pinned by the code visible in this file.
        unsafe { self.map_unchecked_mut(|stream| &mut stream.ring) }
    }

    /// Pin-projects to the `buf` field.
    #[inline(always)]
    fn buf(self: Pin<&mut Self>) -> Pin<&mut Buffer> {
        // SAFETY (reviewer note): plain field projection. `split` also hands
        // out `buf` unpinned, so this presumes `Buffer` does not rely on
        // pinning — confirm.
        unsafe { self.map_unchecked_mut(|stream| &mut stream.buf) }
    }

    /// Borrows the ring (pinned) and the buffer (unpinned) at the same time.
    #[inline(always)]
    fn split(self: Pin<&mut Self>) -> (Pin<&mut Ring<D>>, &mut Buffer) {
        // SAFETY (reviewer note): the two borrows cover disjoint fields and
        // the ring is re-pinned immediately; see `buf` for the assumption
        // about handing out `Buffer` unpinned.
        unsafe {
            let stream = self.get_unchecked_mut();
            (Pin::new_unchecked(&mut stream.ring), &mut stream.buf)
        }
    }
}

/// Future returned by [`TcpStream::connect`] and
/// [`TcpStream::connect_on_driver`].
///
/// Holds either the pending connect submission, or the error produced while
/// creating the socket (taken out the first time it is reported).
pub struct Connect<D: Drive = DemoDriver<'static>>(
    Result<Submission<event::Connect, D>, Option<io::Error>>
);

impl<D: Drive + Clone> Future for Connect<D> {
    type Output = io::Result<TcpStream<D>>;

    /// Drives the connect submission; on success builds a `TcpStream` around
    /// the connected fd and a new `Ring` on a clone of the submission's driver.
    ///
    /// # Panics
    ///
    /// Panics if polled again after a socket-creation error was already
    /// returned.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY (reviewer note): the submission is re-pinned in place and is
        // never moved out of `self`; the error slot is an `Option` that is
        // only `take`n.
        unsafe {
            let state = &mut self.get_unchecked_mut().0;
            match state {
                Ok(pending) => {
                    let mut pending = Pin::new_unchecked(pending);
                    let (event, outcome) = ready!(pending.as_mut().poll(ctx));
                    outcome?;
                    let ring = Ring::new(pending.driver().clone());
                    Poll::Ready(Ok(TcpStream::from_fd(event.fd, ring)))
                }
                Err(slot) => {
                    let failure = slot.take().expect("polled Connect future after completion");
                    Poll::Ready(Err(failure))
                }
            }
        }
    }
}

impl<D: Drive> AsyncRead for TcpStream<D> {
    /// Reads through the buffered path: fills the internal buffer, copies as
    /// much as fits into `buf`, and consumes exactly the bytes copied.
    fn poll_read(mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut [u8])
        -> Poll<io::Result<usize>>
    {
        let mut available = match self.as_mut().poll_fill_buf(ctx) {
            Poll::Ready(Ok(bytes)) => bytes,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        };
        // `io::Read` for `&[u8]` copies from the front of the slice.
        let copied = io::Read::read(&mut available, buf)?;
        self.consume(copied);
        Poll::Ready(Ok(copied))
    }
}

impl<D: Drive> AsyncBufRead for TcpStream<D> {
    /// Marks the stream as reading and asks the buffer to fill itself via a
    /// read on `fd` submitted through the ring.
    ///
    /// When the fill closure is invoked is decided by `Buffer::fill_buf`,
    /// which is not visible in this file.
    fn poll_fill_buf(mut self: Pin<&mut Self>, ctx: &mut Context<'_>)
        -> Poll<io::Result<&[u8]>>
    {
        let fd = self.fd;
        self.as_mut().guard_op(Op::Read);
        let (ring, buffer) = self.split();
        buffer.fill_buf(|dest| {
            // NOTE(review): the `true` flag and the trailing `0` (presumably
            // an offset) are defined by `Ring::poll` / `prep_read`; confirm
            // their meaning there.
            let received = ready!(ring.poll(ctx, true, |sqe| unsafe {
                sqe.prep_read(fd, dest, 0)
            }))?;
            Poll::Ready(Ok(received as u32))
        })
    }

    /// Forwards to the buffer's `consume`.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.buf().consume(amt);
    }
}

impl<D: Drive> AsyncWrite for TcpStream<D> {
    /// Stages `slice` in the internal buffer, submits a write of the staged
    /// bytes on `fd`, clears the buffer once the write completes, and returns
    /// the count reported by the ring.
    fn poll_write(mut self: Pin<&mut Self>, ctx: &mut Context<'_>, slice: &[u8])
        -> Poll<io::Result<usize>>
    {
        let fd = self.fd;
        self.as_mut().guard_op(Op::Write);
        let (ring, buffer) = self.split();
        let staged = ready!(buffer.fill_buf(|mut dest| {
            // `io::Write` for `&mut [u8]` copies as much of `slice` as fits.
            let copied = io::Write::write(&mut dest, slice)?;
            Poll::Ready(Ok(copied as u32))
        }))?;
        let written = ready!(ring.poll(ctx, true, |sqe| unsafe {
            sqe.prep_write(fd, staged, 0)
        }))?;
        buffer.clear();
        Poll::Ready(Ok(written))
    }

    /// Flushes by issuing `poll_write` with an empty slice and discarding the
    /// byte count.
    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_write(ctx, &[]).map(|outcome| outcome.map(|_| ()))
    }

    /// Marks the stream as closing and submits a close of `fd` through the
    /// ring.
    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let fd = self.fd;
        self.as_mut().guard_op(Op::Close);
        ready!(self.ring().poll(ctx, true, |sqe| unsafe {
            uring_sys::io_uring_prep_close(sqe.raw_mut(), fd)
        }))?;
        Poll::Ready(Ok(()))
    }
}