#![forbid(unsafe_code)]
mod channel;
mod oneshot;
/// A boxed, sendable, one-shot job that is handed mutable access to the value `T`.
type Func<T> = Box<dyn FnOnce(&mut T) + 'static + Send>;
/// Configures and constructs an [`Asyncified`] instance.
///
/// Create one with `AsyncifiedBuilder::new()` (or `Default`), adjust it with the
/// chained setters, then finish with `build` or `build_ok`.
#[must_use = "a builder does nothing unless `build` or `build_ok` is called"]
pub struct AsyncifiedBuilder<T> {
    // Capacity handed to `channel::new` for the queue of pending calls.
    channel_size: usize,
    // Callback run on the worker thread with the value after the call loop ends.
    on_close: Option<Func<T>>,
    // Custom builder for the worker thread; `None` means a default builder is used.
    thread_builder: Option<std::thread::Builder>,
}
impl<T: 'static> Default for AsyncifiedBuilder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: 'static> AsyncifiedBuilder<T> {
    /// Creates a builder with the default settings.
    ///
    /// The defaults are:
    /// - a channel capacity of 16;
    /// - no `on_close` callback;
    /// - no custom thread builder, so `build` spawns a thread named `"Asyncified thread"`.
    pub fn new() -> Self {
        Self {
            channel_size: 16,
            on_close: None,
            thread_builder: None,
        }
    }

    /// Sets the capacity passed to the internal channel that queues calls for
    /// the worker thread.
    // NOTE(review): what a `size` of 0 does is decided by `channel::new`, which is
    // not visible here — confirm before relying on it.
    pub fn channel_size(mut self, size: usize) -> Self {
        self.channel_size = size;
        self
    }

    /// Uses `builder` to spawn the worker thread instead of the default builder.
    ///
    /// The default builder names the thread `"Asyncified thread"`.
    pub fn thread_builder(mut self, builder: std::thread::Builder) -> Self {
        self.thread_builder = Some(builder);
        self
    }

    /// Registers a callback that runs on the worker thread, with mutable access
    /// to the value, once the call loop has ended.
    ///
    /// The loop ends when the channel stops yielding calls; presumably this
    /// happens once every `Asyncified` handle has been dropped.
    pub fn on_close<F>(mut self, f: F) -> Self
    where
        F: FnOnce(&mut T) + Send + 'static,
    {
        self.on_close = Some(Box::new(f));
        self
    }

    /// Like [`build`](Self::build), but for a constructor that cannot fail.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned.
    pub async fn build_ok<F>(self, val_fn: F) -> Asyncified<T>
    where
        F: Send + 'static + FnOnce() -> T,
    {
        // `Infallible` proves at the type level that the error arm is unreachable,
        // so no runtime `expect` is needed.
        match self
            .build(move || Ok::<_, std::convert::Infallible>(val_fn()))
            .await
        {
            Ok(asyncified) => asyncified,
            Err(never) => match never {},
        }
    }

    /// Spawns the worker thread, constructs the value on it with `val_fn`, and
    /// resolves once construction has finished.
    ///
    /// The value never leaves the worker thread:
    /// - `val_fn` runs there;
    /// - every closure sent through [`Asyncified::call`] runs there, one at a time;
    /// - the `on_close` callback, if any, runs there after the call loop ends.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `val_fn`. In that case the worker thread
    /// exits without serving any calls.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned.
    pub async fn build<E, F>(self, val_fn: F) -> Result<Asyncified<T>, E>
    where
        E: Send + 'static,
        F: Send + 'static + FnOnce() -> Result<T, E>,
    {
        let Self {
            channel_size,
            on_close,
            thread_builder,
        } = self;
        let thread_builder = thread_builder
            .unwrap_or_else(|| std::thread::Builder::new().name("Asyncified thread".to_string()));
        let (tx, mut rx) = channel::new::<Func<T>>(channel_size);
        let (res_tx, res_rx) = oneshot::new::<Result<(), E>>();
        thread_builder
            .spawn(move || {
                // Construct the value on this thread and report the outcome to the
                // awaiting caller before serving any calls.
                let mut val = match val_fn() {
                    Ok(val) => {
                        res_tx.send(Ok(()));
                        val
                    }
                    Err(e) => {
                        res_tx.send(Err(e));
                        return;
                    }
                };
                // Run queued calls sequentially until the channel yields `None`.
                // NOTE(review): if a call panics, this thread unwinds and `on_close`
                // never runs.
                while let Some(f) = rx.recv() {
                    f(&mut val)
                }
                if let Some(on_close) = on_close {
                    on_close(&mut val)
                }
            })
            .expect("should be able to spawn new thread for Asyncified instance");
        res_rx.recv().await?;
        Ok(Asyncified { tx })
    }
}
/// A handle to a value of type `T` that lives on a dedicated worker thread.
///
/// Closures passed to `call` are boxed and sent over `tx` to the worker, which
/// runs them against the value. Cloning the handle clones the sender, so all
/// clones talk to the same worker.
pub struct Asyncified<T> {
    // Sending half of the channel whose receiver is drained by the worker thread.
    tx: channel::Sender<Func<T>>,
}
impl<T> Clone for Asyncified<T> {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
}
}
}
// Written by hand so that `T` need not implement `Debug`; the sender is shown
// as a fixed placeholder string.
impl<T> std::fmt::Debug for Asyncified<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = "<channel::Sender>";
        let mut out = formatter.debug_struct("Asyncified");
        out.field("tx", &placeholder);
        out.finish()
    }
}
impl<T: 'static> Asyncified<T> {
pub async fn new<E, F>(val_fn: F) -> Result<Asyncified<T>, E>
where
E: Send + 'static,
F: Send + 'static + FnOnce() -> Result<T, E>,
{
Self::builder().build(val_fn).await
}
pub fn builder() -> AsyncifiedBuilder<T> {
AsyncifiedBuilder::new()
}
pub async fn call<R: Send + 'static, F: (FnOnce(&mut T) -> R) + Send + 'static>(
&self,
f: F,
) -> R {
let (tx, rx) = oneshot::new::<R>();
let _ = self
.tx
.send(Box::new(move |item| {
let res = f(item);
tx.send(res);
}))
.await;
rx.recv().await
}
}
#[cfg(test)]
mod test {
    use super::*;
    use std::time::{Duration, Instant};

    // Creating the future must return immediately. Async fns are lazy, so the
    // slow constructor is never reached without polling.
    #[test]
    fn new_doesnt_block() {
        let start = Instant::now();
        let _fut = Asyncified::new(|| {
            std::thread::sleep(Duration::from_secs(10));
            Ok::<_, ()>(())
        });
        assert!(start.elapsed().as_millis() < 100);
    }

    // Creating a `call` future must not wait for the (slow) closure to run.
    #[tokio::test]
    async fn call_doesnt_block() {
        let a = Asyncified::new(|| Ok::<_, ()>(())).await.unwrap();
        let start = Instant::now();
        let _fut = a.call(|_| {
            std::thread::sleep(Duration::from_secs(10));
        });
        assert!(start.elapsed().as_millis() < 100);
    }

    #[tokio::test]
    async fn basic_updating_works() {
        let a = Asyncified::new(|| Ok::<_, ()>(0u64)).await.unwrap();
        for i in 1..100_000 {
            assert_eq!(
                a.call(|n| {
                    *n += 1;
                    *n
                })
                .await,
                i
            );
        }
    }

    #[tokio::test]
    async fn parallel_updating_works() {
        let a = Asyncified::new(|| Ok::<_, ()>(0u64)).await.unwrap();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let a = a.clone();
                tokio::spawn(async move {
                    for _ in 0..10_000 {
                        a.call(|n| {
                            *n += 1;
                            *n
                        })
                        .await;
                    }
                })
            })
            .collect();
        for handle in handles {
            // Report a panicking task directly instead of only through a wrong
            // final count.
            handle.await.expect("updating task panicked");
        }
        assert_eq!(a.call(|n| *n).await, 100_000);
    }

    #[tokio::test]
    async fn on_close_is_called() {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let a = Asyncified::builder()
            .on_close(move |val| {
                let _ = tx.send(*val);
            })
            .build_ok(|| 0u64)
            .await;
        a.call(|v| *v = 100).await;
        drop(a);
        assert_eq!(rx.await.unwrap(), 100);
    }
}