use std::{fmt, fs::File, io, marker::PhantomData, path::Path, sync::mpsc, thread};
use crate::{Error, RawMem, file_mapped::FileMapped};
enum Command<T: Copy + Send + 'static> {
Len(tokio::sync::oneshot::Sender<usize>),
Get(usize, tokio::sync::oneshot::Sender<Option<T>>),
Set(usize, T, tokio::sync::oneshot::Sender<bool>),
ReadSlice(usize, usize, tokio::sync::oneshot::Sender<Result<Vec<T>, Error>>),
WriteSlice(usize, Vec<T>, tokio::sync::oneshot::Sender<Result<(), Error>>),
GrowFilled(usize, T, tokio::sync::oneshot::Sender<Result<(), Error>>),
GrowZeroed(usize, tokio::sync::oneshot::Sender<Result<(), Error>>),
GrowAssumed(usize, tokio::sync::oneshot::Sender<Result<(), Error>>),
Shrink(usize, tokio::sync::oneshot::Sender<Result<(), Error>>),
Shutdown(tokio::sync::oneshot::Sender<()>),
}
/// Async handle to a `FileMapped<T>` that lives on its own OS thread.
///
/// Operations are shipped to that thread as `Command`s over a std `mpsc` channel, and
/// results come back on tokio oneshot channels, so async callers never touch the
/// mapping directly.
pub struct AsyncFileMem<T: Copy + Send + 'static> {
    /// Command sender; `None` once `shutdown` has run or the handle is being dropped.
    tx: Option<mpsc::Sender<Command<T>>>,
    /// Join handle of the worker thread; taken (left `None`) once it has been joined.
    thread: Option<thread::JoinHandle<()>>,
    /// Marks the element type `T` on the handle itself.
    _marker: PhantomData<T>,
}
fn spawn_io_thread<T: Copy + Clone + Send + 'static>(
mut mem: FileMapped<T>,
rx: mpsc::Receiver<Command<T>>,
) -> thread::JoinHandle<()> {
thread::spawn(move || {
while let Ok(cmd) = rx.recv() {
match cmd {
Command::Len(reply) => {
let _ = reply.send(mem.allocated().len());
}
Command::Get(index, reply) => {
let val = mem.allocated().get(index).copied();
let _ = reply.send(val);
}
Command::Set(index, value, reply) => {
let slice = mem.allocated_mut();
let ok = if index < slice.len() {
slice[index] = value;
true
} else {
false
};
let _ = reply.send(ok);
}
Command::ReadSlice(offset, count, reply) => {
let slice = mem.allocated();
let result = if offset.saturating_add(count) <= slice.len() {
Ok(slice[offset..offset + count].to_vec())
} else {
Err(Error::CapacityOverflow)
};
let _ = reply.send(result);
}
Command::WriteSlice(offset, values, reply) => {
let slice = mem.allocated_mut();
let result = if offset.saturating_add(values.len()) <= slice.len() {
slice[offset..offset + values.len()].copy_from_slice(&values);
Ok(())
} else {
Err(Error::CapacityOverflow)
};
let _ = reply.send(result);
}
Command::GrowFilled(count, value, reply) => {
let result = mem.grow_filled(count, value).map(|_| ());
let _ = reply.send(result);
}
Command::GrowZeroed(count, reply) => {
let result = unsafe { mem.grow_zeroed(count).map(|_| ()) };
let _ = reply.send(result);
}
Command::GrowAssumed(count, reply) => {
let result = unsafe { mem.grow_assumed(count).map(|_| ()) };
let _ = reply.send(result);
}
Command::Shrink(count, reply) => {
let result = mem.shrink(count);
let _ = reply.send(result);
}
Command::Shutdown(reply) => {
drop(mem);
let _ = reply.send(());
return;
}
}
}
})
}
impl<T: Copy + Clone + Send + 'static> AsyncFileMem<T> {
    /// Opens `path` as a `FileMapped<T>` and hands it to a freshly spawned worker thread.
    ///
    /// # Errors
    /// Propagates any error returned by `FileMapped::from_path`.
    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let mem = FileMapped::<T>::from_path(path)?;
        Ok(Self::from_file_mapped(mem))
    }

    /// Maps a new temporary file obtained from `tempfile::tempfile()`.
    ///
    /// # Errors
    /// Fails if the temporary file cannot be created or `FileMapped::new` rejects it.
    pub fn temp() -> Result<Self, Error> {
        let file = tempfile::tempfile()?;
        let mem = FileMapped::<T>::new(file)?;
        Ok(Self::from_file_mapped(mem))
    }

    /// Maps an already-open `file` and moves the mapping onto a worker thread.
    ///
    /// # Errors
    /// Propagates any error returned by `FileMapped::new`.
    pub fn from_file(file: File) -> Result<Self, Error> {
        let mem = FileMapped::<T>::new(file)?;
        Ok(Self::from_file_mapped(mem))
    }

    /// Creates the command channel and spawns the worker thread that owns `mem`.
    fn from_file_mapped(mem: FileMapped<T>) -> Self {
        let (tx, rx) = mpsc::channel();
        let thread = spawn_io_thread(mem, rx);
        Self { tx: Some(tx), thread: Some(thread), _marker: PhantomData }
    }

    /// Builds a command around a fresh oneshot reply channel, sends it to the worker,
    /// and awaits the answer.
    ///
    /// # Errors
    /// Returns a `BrokenPipe` I/O error (converted into `Error`) in three cases:
    /// `shutdown` has already been called, the worker thread is gone, or the worker
    /// dropped the reply sender without answering.
    async fn send_command<R: Send + 'static>(
        &self,
        make_cmd: impl FnOnce(tokio::sync::oneshot::Sender<R>) -> Command<T>,
    ) -> Result<R, Error> {
        // Using the handle after `shutdown` is a recoverable caller mistake, so report it
        // as an error instead of panicking inside an async task.
        let tx = self.tx.as_ref().ok_or_else(|| {
            io::Error::new(io::ErrorKind::BrokenPipe, "AsyncFileMem already shut down")
        })?;
        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
        tx.send(make_cmd(reply_tx)).map_err(|_| {
            io::Error::new(io::ErrorKind::BrokenPipe, "I/O thread terminated unexpectedly")
        })?;
        reply_rx.await.map_err(|_| {
            Error::from(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "I/O thread dropped reply channel",
            ))
        })
    }

    /// Returns the number of elements in the mapping's allocated region.
    ///
    /// # Errors
    /// Fails if the handle has been shut down or the worker thread is gone.
    pub async fn len(&self) -> Result<usize, Error> {
        self.send_command(Command::Len).await
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    ///
    /// # Errors
    /// Same as [`len`](Self::len).
    pub async fn is_empty(&self) -> Result<bool, Error> {
        Ok(self.len().await? == 0)
    }

    /// Returns a copy of the element at `index`, or `Ok(None)` if `index` is out of range.
    ///
    /// # Errors
    /// Fails if the handle has been shut down or the worker thread is gone.
    pub async fn get(&self, index: usize) -> Result<Option<T>, Error> {
        self.send_command(|tx| Command::Get(index, tx)).await
    }

    /// Stores `value` at `index`. Returns `Ok(false)`, writing nothing, if `index` is
    /// out of range.
    ///
    /// # Errors
    /// Fails if the handle has been shut down or the worker thread is gone.
    pub async fn set(&self, index: usize, value: T) -> Result<bool, Error> {
        self.send_command(|tx| Command::Set(index, value, tx)).await
    }

    /// Copies `count` elements starting at `offset` into a new `Vec`.
    ///
    /// # Errors
    /// Returns `Error::CapacityOverflow` if `offset + count` exceeds the current length.
    /// Also fails if the handle has been shut down or the worker thread is gone.
    pub async fn read_slice(&self, offset: usize, count: usize) -> Result<Vec<T>, Error> {
        self.send_command(|tx| Command::ReadSlice(offset, count, tx)).await?
    }

    /// Overwrites elements starting at `offset` with `values`.
    ///
    /// # Errors
    /// Returns `Error::CapacityOverflow` if `offset + values.len()` exceeds the current
    /// length. Also fails if the handle has been shut down or the worker thread is gone.
    pub async fn write_slice(&self, offset: usize, values: Vec<T>) -> Result<(), Error> {
        self.send_command(|tx| Command::WriteSlice(offset, values, tx)).await?
    }

    /// Runs `FileMapped::grow_filled(count, value)` on the worker thread (presumably
    /// appending `count` copies of `value`; see `FileMapped` for the exact semantics).
    ///
    /// # Errors
    /// Propagates the `FileMapped` error. Also fails if the handle has been shut down
    /// or the worker thread is gone.
    pub async fn grow_filled(&self, count: usize, value: T) -> Result<(), Error> {
        self.send_command(|tx| Command::GrowFilled(count, value, tx)).await?
    }

    /// Runs `FileMapped::grow_zeroed(count)` on the worker thread.
    ///
    /// # Safety
    /// The caller must uphold `FileMapped::grow_zeroed`'s contract. Presumably this means
    /// an all-zero bit pattern must be a valid `T`; confirm against the `FileMapped` docs.
    ///
    /// # Errors
    /// Propagates the `FileMapped` error. Also fails if the handle has been shut down
    /// or the worker thread is gone.
    pub async unsafe fn grow_zeroed(&self, count: usize) -> Result<(), Error> {
        self.send_command(|tx| Command::GrowZeroed(count, tx)).await?
    }

    /// Runs `FileMapped::grow_assumed(count)` on the worker thread.
    ///
    /// # Safety
    /// The caller must uphold `FileMapped::grow_assumed`'s contract. Presumably this means
    /// the newly exposed storage must already hold valid `T` values; confirm against the
    /// `FileMapped` docs.
    ///
    /// # Errors
    /// Propagates the `FileMapped` error. Also fails if the handle has been shut down
    /// or the worker thread is gone.
    pub async unsafe fn grow_assumed(&self, count: usize) -> Result<(), Error> {
        self.send_command(|tx| Command::GrowAssumed(count, tx)).await?
    }

    /// Runs `FileMapped::shrink(count)` on the worker thread.
    ///
    /// # Errors
    /// Propagates the `FileMapped` error. Also fails if the handle has been shut down
    /// or the worker thread is gone.
    pub async fn shrink(&self, count: usize) -> Result<(), Error> {
        self.send_command(|tx| Command::Shrink(count, tx)).await?
    }

    /// Stops the worker thread.
    ///
    /// This sends `Command::Shutdown`, waits until the worker has dropped the mapping,
    /// and then joins the thread. Calling it again is a no-op that returns `Ok(())`.
    /// After shutdown, every other method returns an error.
    ///
    /// # Errors
    /// Returns an I/O error ("I/O thread panicked") if the worker thread panicked.
    pub async fn shutdown(&mut self) -> Result<(), Error> {
        if let Some(tx) = self.tx.take() {
            let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
            // If the worker is already gone, these fail; the join below reports any panic.
            let _ = tx.send(Command::Shutdown(reply_tx));
            let _ = reply_rx.await;
        }
        if let Some(thread) = self.thread.take() {
            thread.join().map_err(|_| io::Error::other("I/O thread panicked"))?;
        }
        Ok(())
    }
}
impl<T: Copy + Send + 'static> Drop for AsyncFileMem<T> {
    /// Closes the command channel and waits for the worker thread to exit.
    ///
    /// Commands that were already queued are still received by the worker before its
    /// `recv` reports disconnection. If `shutdown` already ran, both fields are `None`
    /// and nothing happens.
    /// NOTE(review): `join` blocks the dropping thread, even an async runtime worker.
    fn drop(&mut self) {
        self.tx = None;
        let Some(worker) = self.thread.take() else {
            return;
        };
        // A panic on the worker thread is deliberately ignored while dropping.
        let _ = worker.join();
    }
}
impl<T: Copy + Send + 'static> fmt::Debug for AsyncFileMem<T> {
    /// Prints only whether the command channel is still open (`active`). The mapped
    /// contents are owned by the worker thread and are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let active = self.tx.is_some();
        let mut builder = f.debug_struct("AsyncFileMem");
        builder.field("active", &active);
        builder.finish()
    }
}