use core::{async_iter::AsyncIterator, future::poll_fn, mem, pin::pin};
use std::{io, net::Shutdown};
use xitca_io::{
bytes::BytesMut,
io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all},
net::io_uring::TcpStream,
};
use xitca_unsafe_collection::futures::{Select, SelectOutput};
use crate::{error::Error, protocol::message::backend};
use super::generic::{DriverRx, GenericDriver, WriteState};
/// io_uring driver specialised to the io_uring `TcpStream`.
pub type UringDriver = GenericUringDriver<TcpStream>;

impl UringDriver {
    /// Rebuilds a [`GenericDriver`] running over `xitca_io::net::TcpStream` as a [`UringDriver`].
    ///
    /// The socket is converted through its std representation. The read buffer's inner
    /// `BytesMut` and the `rx` handle are carried over unchanged. Any other state held by the
    /// generic driver (ignored via `..`) is discarded. Both read and write halves start in
    /// `State::Running`.
    ///
    /// # Panics
    /// Panics if converting the socket with `into_std` returns an error.
    pub(crate) fn from_tcp(drv: GenericDriver<xitca_io::net::TcpStream>) -> Self {
        let GenericDriver {
            io: tcp,
            read_buf: buffered,
            rx,
            ..
        } = drv;

        // Same order as field construction previously: socket first, then the buffer.
        let io = TcpStream::from_std(tcp.into_std().unwrap());
        let read_buf = buffered.into_inner();

        Self {
            io,
            read_buf,
            rx,
            read_state: State::Running,
            write_state: State::Running,
        }
    }
}
/// Postgres connection driver that performs socket I/O with buffer-owning (io_uring style)
/// reads and writes.
///
/// Generic over the I/O type; [`UringDriver`] is the `TcpStream` instantiation. Use
/// `into_iter` to obtain the async iterator that actually drives the connection.
pub struct GenericUringDriver<Io> {
    // Underlying stream; reads, writes and shutdown are all issued through `&self.io`.
    io: Io,
    // Bytes received from the stream that have not yet been decoded into backend messages.
    read_buf: BytesMut,
    // Driver-side handle shared with the client: `wait()` signals pending writes,
    // `guarded` holds the outgoing bytes, and `try_decode` parses incoming bytes.
    rx: DriverRx,
    // Whether the read half is still active, or how it ended.
    read_state: State,
    // Whether the write half is still active, or how it ended.
    write_state: State,
}
/// Lifecycle of one half (read or write) of the driver's I/O.
pub(super) enum State {
    /// The half is still being polled.
    Running,
    /// The half has finished. `Some` carries the I/O error that stopped it; `None` means it
    /// ended without an error.
    Closed(Option<io::Error>),
}
impl<Io> GenericUringDriver<Io>
where
    Io: AsyncBufRead + AsyncBufWrite + 'static,
{
    /// Consumes the driver and returns an async iterator that drives the connection.
    ///
    /// Polling the iterator runs two tasks over the same stream:
    /// - a writer that, each time `rx.wait()` reports `WriteState::WantWrite`, takes all
    ///   pending outgoing bytes and writes them with `write_all`. Any other wait result
    ///   ends the writer without error; a write error ends it with that error.
    /// - a reader that decodes backend messages from `read_buf` and reads more bytes from
    ///   the stream whenever no complete message is buffered.
    ///
    /// Every decoded message is yielded as `Ok`. The iterator ends:
    /// - right after yielding the error when decoding fails;
    /// - once both halves are closed:
    ///   - If both closed without error, the stream is shut down in both directions, and a
    ///     shutdown failure is yielded as a final `Err`.
    ///   - Otherwise, one `Error::driver_io` combining the read and write errors is yielded.
    pub fn into_iter(mut self) -> impl AsyncIterator<Item = Result<backend::Message, Error>> + use<Io> {
        // Move the buffer out of `self` into a standalone local owned by the generator.
        // NOTE(review): presumably this keeps the reader's mutable use of the buffer separate
        // from the shared uses of `self` by both tasks. Confirm.
        let mut read_buf = mem::take(&mut self.read_buf);
        async gen move {
            // Writer task: flush outgoing bytes whenever a write is requested.
            let write = || async {
                loop {
                    match self.rx.wait().await {
                        WriteState::WantWrite => {
                            // The lock guard is a temporary dropped at the end of this statement,
                            // so it is not held across the await below.
                            let buf = self.rx.guarded.lock().unwrap().buf.split();
                            // The buffer handed back by `write_all` is discarded.
                            let (res, _) = write_all(&self.io, buf).await;
                            res?;
                        }
                        _ => return Ok::<_, io::Error>(()),
                    }
                }
            };
            // Reader task: an async generator of decoded messages. Errors are tagged:
            // `SelectOutput::A` = decode error (ends the whole driver),
            // `SelectOutput::B` = stream I/O error (closes only the read half).
            let read = || async gen {
                loop {
                    // Drain every complete message already buffered before reading more.
                    match self.rx.try_decode(&mut read_buf) {
                        Ok(Some(msg)) => {
                            yield Ok(msg);
                            continue;
                        }
                        Err(e) => {
                            yield Err(SelectOutput::A(e));
                            return;
                        }
                        Ok(None) => {}
                    }
                    // Read into the region after the existing `len` bytes, with at least 4096
                    // bytes of spare capacity reserved.
                    let len = read_buf.len();
                    read_buf.reserve(4096);
                    // The read takes ownership of the buffer and hands it back with the result.
                    let (res, b) = self.io.read(read_buf.slice(len..)).await;
                    read_buf = b.into_inner();
                    match res {
                        // Zero-byte read: end of stream. The generator finishes, which closes the
                        // read half cleanly. Any undecoded bytes left in `read_buf` are dropped.
                        Ok(0) => return,
                        Ok(_) => {}
                        Err(e) => {
                            yield Err(SelectOutput::B(e));
                            return;
                        }
                    }
                }
            };
            // Both tasks are pinned once, outside the loop, and polled through `as_mut()`.
            // An in-progress read or write is therefore resumed, not restarted, on later
            // iterations.
            let mut read = pin!(read());
            let mut write = pin!(write());
            loop {
                let res = match (&mut self.write_state, &mut self.read_state) {
                    // Both halves alive: wait for whichever produces output first.
                    (State::Running, State::Running) => {
                        write.as_mut().select(poll_fn(|cx| read.as_mut().poll_next(cx))).await
                    }
                    (State::Running, _) => SelectOutput::A(write.as_mut().await),
                    (_, State::Running) => SelectOutput::B(poll_fn(|cx| read.as_mut().poll_next(cx)).await),
                    // Both halves ended without error: shut the stream down and finish.
                    (State::Closed(None), State::Closed(None)) => {
                        if let Err(e) = self.io.shutdown(Shutdown::Both) {
                            yield Err(e.into());
                        }
                        return;
                    }
                    // At least one half ended with an error: report both halves' errors together.
                    // No shutdown is attempted on this path.
                    (State::Closed(err_w), State::Closed(err_r)) => {
                        yield Err(Error::driver_io(err_r.take(), err_w.take()));
                        return;
                    }
                };
                match res {
                    // Writer finished: record how it ended; it is not polled again.
                    SelectOutput::A(Ok(_)) => self.write_state = State::Closed(None),
                    SelectOutput::A(Err(e)) => self.write_state = State::Closed(Some(e)),
                    SelectOutput::B(Some(Ok(msg))) => {
                        yield Ok(msg);
                    }
                    SelectOutput::B(Some(Err(e))) => match e {
                        // Decode error: surface it and stop the driver immediately.
                        SelectOutput::A(e) => {
                            yield Err(e);
                            return;
                        }
                        // Stream error: close the read half; the writer may keep running.
                        SelectOutput::B(e) => {
                            self.read_state = State::Closed(Some(e));
                        }
                    },
                    // Reader generator finished (end of stream).
                    SelectOutput::B(None) => {
                        self.read_state = State::Closed(None);
                    }
                }
            }
        }
    }
}
#[cfg(not(feature = "tls"))]
#[cfg(test)]
mod test {
    use core::{future::poll_fn, pin::pin};

    use crate::{Execute, Postgres, Statement, iter::AsyncLendingIterator};

    use super::*;

    /// End-to-end check of the io_uring driver.
    ///
    /// Drives the connection on an io_uring runtime in a dedicated thread, runs `SELECT 1`
    /// from the client side and checks the returned value. Requires a Postgres server
    /// reachable at `localhost:5432`.
    #[tokio::test]
    async fn io_uring_drv() {
        let (client, driver) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();

        // The io_uring runtime runs on its own OS thread and drains the driver until the
        // iterator is exhausted. The test expects this once the client has been dropped.
        let driver_thread = std::thread::spawn(move || {
            tokio_uring_xitca::start(async move {
                let mut messages = pin!(driver.try_into_uring().unwrap().into_iter());
                loop {
                    if poll_fn(|cx| messages.as_mut().poll_next(cx)).await.is_none() {
                        break;
                    }
                }
            })
        });

        // Scoped so the statement guard, row stream and row drop (in that reverse order)
        // before the client is dropped below.
        let value = {
            let stmt = Statement::named("SELECT 1", &[]).execute(&client).await.unwrap();
            let mut rows = stmt.query(&client).await.unwrap();
            let row = rows.try_next().await.unwrap().unwrap();
            row.get::<i32>(0)
        };
        assert_eq!(value, 1);

        drop(client);
        driver_thread.join().unwrap()
    }
}