pub mod uring;
use std::cell::RefCell;
use std::future::Future;
use std::os::unix::io::FromRawFd;
use std::rc::Rc;
use std::{io, net, ops, time};
use futures::executor;
use futures::task::LocalSpawn;
use nix::sys::socket;
/// Owned, fixed-size byte buffer that is moved into ring operations and handed back
/// together with the stream once the operation completes.
pub type Buffer = Box<[u8]>;
/// Single-threaded driver pairing an io_uring wrapper with a `futures` local task pool.
///
/// Tasks are added with `Runner::spawn` (or `Context::spawn`); `Runner::run` alternates
/// between polling the pool and waiting on the ring for completions.
pub struct Runner {
    // Shared with every `Context`, so all tasks submit to the same ring.
    // NOTE(review): fields drop in declaration order, so this `Rc` handle is released before
    // the pool; the ring itself lives until the last `Context` clone (including those held by
    // pooled futures) is dropped. Re-check before reordering these fields.
    uring: Rc<RefCell<uring::Uring>>,
    // Local executor that owns and polls the spawned futures.
    pool: executor::LocalPool,
}
impl Runner {
    /// Creates a runner with a fresh ring and an empty task pool.
    ///
    /// `size` is forwarded unchanged to `uring::Uring::new` (presumably the ring's entry
    /// count — confirm against the `uring` module).
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the ring.
    pub fn new(size: u32) -> anyhow::Result<Self> {
        let ring = uring::Uring::new(size)?;
        Ok(Self {
            uring: Rc::new(RefCell::new(ring)),
            pool: executor::LocalPool::new(),
        })
    }

    /// Spawns a top-level task.
    ///
    /// `c` receives a fresh `Context` bound to this runner's ring and pool, and returns the
    /// future to execute. The future only makes progress once `run` is called.
    ///
    /// # Panics
    ///
    /// Panics with "failed to spawn" if the pool's spawner refuses the task.
    pub fn spawn<C, F>(&mut self, c: C)
    where
        C: FnOnce(Context) -> F,
        F: Future<Output = ()> + 'static,
    {
        let spawner = self.pool.spawner();
        let future = c(Context::new(Rc::clone(&self.uring), spawner.clone()));
        spawner
            .spawn_local_obj(Box::pin(future).into())
            .expect("failed to spawn");
    }

    /// Drives the event loop: polls every ready task, then submits queued operations and
    /// blocks until the ring reports completions.
    ///
    /// This never returns `Ok`; the loop only ends through an error.
    ///
    /// # Errors
    ///
    /// Propagates errors from `submit_and_wait`, and fails with "no more pending IO" once a
    /// wait reports zero completions.
    pub fn run(&mut self) -> anyhow::Result<()> {
        loop {
            self.pool.run_until_stalled();
            // The mutable borrow is a temporary, released before the next round of polling.
            let completed = self.uring.borrow_mut().submit_and_wait()?;
            if completed == 0 {
                anyhow::bail!("no more pending IO");
            }
        }
    }
}
/// Handle passed to tasks for submitting I/O to the shared ring and spawning sibling tasks.
///
/// Cloning copies the `Rc` handle to the same ring (no new ring is created) and clones the
/// spawner.
#[derive(Clone)]
pub struct Context {
    // Ring shared with the `Runner` and every other `Context` clone.
    uring: Rc<RefCell<uring::Uring>>,
    // Spawner for the runner's local pool, used by `Context::spawn`.
    spawner: executor::LocalSpawner,
}
impl Context {
pub fn new(uring: Rc<RefCell<uring::Uring>>, spawner: executor::LocalSpawner) -> Self {
Self { uring, spawner }
}
pub fn spawn<C, F>(&mut self, c: C)
where
C: FnOnce(Context) -> F,
F: Future<Output = ()> + 'static,
{
let f = c(self.clone());
self.spawner
.spawn_local_obj(Box::pin(f).into())
.expect("failed to spawn");
}
pub async fn accept(
&mut self,
socket: net::TcpListener,
) -> io::Result<(net::TcpListener, net::TcpStream)> {
let submission = uring::Submission::Accept { socket };
let task = self
.uring
.borrow_mut()
.submit(submission, uring::Flags::empty());
match task.await? {
uring::Completion::Accept { socket, stream } => Ok((socket, stream)),
_ => unreachable!(),
}
}
pub fn buffer(&mut self, size: usize) -> Buffer {
vec![0; size].into_boxed_slice()
}
pub async fn connect<A>(&mut self, addr: A) -> anyhow::Result<net::TcpStream>
where
A: net::ToSocketAddrs,
{
let fd = socket::socket(
socket::AddressFamily::Inet,
socket::SockType::Stream,
socket::SockFlag::empty(),
socket::SockProtocol::Tcp,
)?;
let socket = unsafe { net::TcpStream::from_raw_fd(fd) };
let addr = addr.to_socket_addrs()?.next().unwrap();
let submission = uring::Submission::Connect { socket, addr };
let task = self
.uring
.borrow_mut()
.submit(submission, uring::Flags::empty());
match task.await? {
uring::Completion::Connect { stream } => Ok(stream),
_ => unreachable!(),
}
}
pub fn listen<A>(&mut self, addr: A) -> io::Result<net::TcpListener>
where
A: net::ToSocketAddrs,
{
net::TcpListener::bind(addr)
}
pub async fn read<R>(
&mut self,
stream: net::TcpStream,
buffer: Buffer,
range: R,
) -> io::Result<(net::TcpStream, Buffer, ops::Range<usize>)>
where
R: ops::RangeBounds<usize>,
{
let range = range_bounds(range, 0, buffer.len());
let submission = uring::Submission::Read {
stream,
buffer,
range,
};
let task = self
.uring
.borrow_mut()
.submit(submission, uring::Flags::empty());
match task.await? {
uring::Completion::Read {
stream,
buffer,
range,
} => Ok((stream, buffer, range)),
_ => unreachable!(),
}
}
pub async fn read_full<R>(
&mut self,
mut stream: net::TcpStream,
mut buffer: Buffer,
range: R,
) -> io::Result<(net::TcpStream, Buffer, ops::Range<usize>)>
where
R: ops::RangeBounds<usize>,
{
let mut range = range_bounds(range, 0, buffer.len());
let goal = range.start..range.end;
loop {
let (s, b, actual) = self.read(stream, buffer, range).await?;
if actual.start == actual.end || actual.end == goal.end {
return Ok((s, b, goal.start..actual.end));
}
stream = s;
buffer = b;
range = actual.end..goal.end;
}
}
pub async fn write<R>(
&mut self,
stream: net::TcpStream,
buffer: Buffer,
range: R,
) -> io::Result<(net::TcpStream, Buffer, ops::Range<usize>)>
where
R: ops::RangeBounds<usize>,
{
let range = range_bounds(range, 0, buffer.len());
let submission = uring::Submission::Write {
stream,
buffer,
range,
};
let task = self
.uring
.borrow_mut()
.submit(submission, uring::Flags::empty());
match task.await? {
uring::Completion::Write {
stream,
buffer,
range,
} => Ok((stream, buffer, range)),
_ => unreachable!(),
}
}
pub async fn write_then<R>(
&mut self,
stream: net::TcpStream,
buffer: Buffer,
range: R,
) -> io::Result<(net::TcpStream, Buffer, ops::Range<usize>)>
where
R: ops::RangeBounds<usize>,
{
let range = range_bounds(range, 0, buffer.len());
let submission = uring::Submission::Write {
stream,
buffer,
range,
};
let task = self
.uring
.borrow_mut()
.submit(submission, uring::Flags::IO_LINK);
match task.await? {
uring::Completion::Write {
stream,
buffer,
range,
} => Ok((stream, buffer, range)),
_ => unreachable!(),
}
}
pub async fn write_full<R>(
&mut self,
mut stream: net::TcpStream,
mut buffer: Buffer,
range: R,
) -> io::Result<(net::TcpStream, Buffer)>
where
R: ops::RangeBounds<usize>,
{
let mut range = range_bounds(range, 0, buffer.len());
let goal = range.end;
loop {
let (s, b, actual) = self.write(stream, buffer, range).await?;
if actual.end == goal {
return Ok((s, b));
}
stream = s;
buffer = b;
range = actual.end..goal;
}
}
pub fn cancel(&mut self) {
}
pub fn timeout(&mut self, _expires: time::Duration) -> Self {
self.clone()
}
}
/// Owned outcome of an accept: the listener handed back plus the newly accepted stream.
///
/// NOTE(review): `Context::accept` in this file returns a tuple rather than this struct.
#[derive(Debug)]
pub struct AcceptResult {
    /// Listening socket, returned so it can be reused for further accepts.
    pub socket: net::TcpListener,
    /// Connection that was accepted.
    pub stream: net::TcpStream,
}
/// Owned outcome of a read: the stream and buffer handed back plus the range reported by
/// the completion.
///
/// NOTE(review): `Context::read` in this file returns a tuple rather than this struct.
#[derive(Debug)]
pub struct ReadResult {
    /// Stream that was read from.
    pub stream: net::TcpStream,
    /// Buffer the data was read into.
    pub buffer: Buffer,
    /// Range of `buffer` reported by the completion (presumably the bytes filled).
    pub range: ops::Range<usize>,
}
/// Owned outcome of a write: the stream and buffer handed back plus the range reported by
/// the completion.
///
/// NOTE(review): `Context::write` in this file returns a tuple rather than this struct.
#[derive(Debug)]
pub struct WriteResult {
    /// Stream that was written to.
    pub stream: net::TcpStream,
    /// Buffer the data was written from.
    pub buffer: Buffer,
    /// Range of `buffer` reported by the completion (presumably the bytes written).
    pub range: ops::Range<usize>,
}
/// Converts any `RangeBounds<usize>` into a concrete half-open `start..end`.
///
/// An unbounded start becomes `min` and an unbounded end becomes `max`; inclusive ends and
/// exclusive starts are shifted by one.
///
/// # Panics
///
/// The resolved range is handed to ring operations that address the buffer directly, so it
/// is validated like slice indexing. This function panics if:
/// - shifting an inclusive end or exclusive start would overflow `usize`;
/// - `start < min`;
/// - `start > end`;
/// - `end > max`.
fn range_bounds<R>(range: R, min: usize, max: usize) -> ops::Range<usize>
where
    R: ops::RangeBounds<usize>,
{
    let start = match range.start_bound() {
        ops::Bound::Included(&n) => n,
        ops::Bound::Excluded(&n) => n
            .checked_add(1)
            .expect("range start bound overflows usize"),
        ops::Bound::Unbounded => min,
    };
    let end = match range.end_bound() {
        ops::Bound::Included(&n) => n.checked_add(1).expect("range end bound overflows usize"),
        ops::Bound::Excluded(&n) => n,
        ops::Bound::Unbounded => max,
    };
    assert!(
        min <= start,
        "range start {} is below the minimum {}",
        start,
        min
    );
    assert!(
        start <= end,
        "range start {} is greater than range end {}",
        start,
        end
    );
    assert!(end <= max, "range end {} exceeds the maximum {}", end, max);
    start..end
}