#![allow(dead_code)]
#![allow(unused_variables)]
use std::{future::Future, pin::Pin, rc::Rc};
use bitwarden_error::bitwarden_error;
use thiserror::Error;
#[cfg(not(target_arch = "wasm32"))]
use tokio::task::spawn_local;
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
use wasm_bindgen_futures::spawn_local;
/// Type-erased unit of work handled by the runner.
///
/// A boxed, call-once closure that receives a shared `Rc` handle to the state and returns a
/// boxed, pinned future producing `()`. The closure itself is required to be `Send + Sync`;
/// the future it returns carries no `Send` bound.
type CallFunction<ThreadState> =
Box<dyn FnOnce(Rc<ThreadState>) -> Pin<Box<dyn Future<Output = ()>>> + Send + Sync>;
/// Message carried over the runner's call channel: one pending call to execute.
struct CallRequest<ThreadState> {
/// The type-erased closure to invoke with a handle to the state.
function: CallFunction<ThreadState>,
}
/// Error returned by `ThreadBoundRunner::run_in_thread` when a call finished without
/// delivering a result value to the caller.
///
/// The wrapped string describes the underlying failure and is interpolated into the
/// `Display` message.
#[derive(Debug, Error)]
#[error("The call failed before it could return a value: {0}")]
#[bitwarden_error(basic)]
pub struct CallError(String);
/// Handle used to submit calls that are executed against a `ThreadState` value.
///
/// The runner itself only stores the sending half of a `tokio::sync::mpsc` channel of
/// `CallRequest`s; it never holds the state directly.
pub struct ThreadBoundRunner<ThreadState> {
/// Sending half of the channel on which call requests are queued.
call_channel_tx: tokio::sync::mpsc::Sender<CallRequest<ThreadState>>,
}
impl<ThreadState> Clone for ThreadBoundRunner<ThreadState> {
fn clone(&self) -> Self {
ThreadBoundRunner {
call_channel_tx: self.call_channel_tx.clone(),
}
}
}
impl<ThreadState> ThreadBoundRunner<ThreadState>
where
    ThreadState: 'static,
{
    /// Creates a runner that takes ownership of `state`.
    ///
    /// `state` is moved into a task started with `spawn_local`, where it is wrapped in an
    /// `Rc`. That task receives requests from the call channel and starts every request as
    /// its own `spawn_local` task, handing it a clone of the `Rc`.
    ///
    /// # Panics
    ///
    /// NOTE(review): on non-wasm targets `spawn_local` is `tokio::task::spawn_local`, which
    /// is documented to panic when called outside of a `LocalSet` — confirm for the tokio
    /// version in use.
    pub fn new(state: ThreadState) -> Self {
        // Capacity 1: a sender waits until the receiving task has taken the previous request.
        let (call_channel_tx, mut call_channel_rx) =
            tokio::sync::mpsc::channel::<CallRequest<ThreadState>>(1);

        spawn_local(async move {
            let state = Rc::new(state);
            // `recv` yields `None` once every sender has been dropped, which ends this task.
            while let Some(request) = call_channel_rx.recv().await {
                // Each call gets its own task so the receive loop is not blocked by it.
                spawn_local((request.function)(Rc::clone(&state)));
            }
        });

        ThreadBoundRunner { call_channel_tx }
    }

    /// Runs `function` with a handle to the state and returns the output of its future.
    ///
    /// The call is packaged into a `CallRequest`, queued on the call channel, and its result
    /// is sent back through a one-shot channel that this method awaits.
    ///
    /// # Errors
    ///
    /// Returns [`CallError`] when
    /// - the request cannot be queued because the receiving side of the call channel has
    ///   been closed, or
    /// - the one-shot result channel is closed before a value was sent (the test suite
    ///   exercises this with a function that panics).
    pub async fn run_in_thread<F, Fut, Output>(&self, function: F) -> Result<Output, CallError>
    where
        F: FnOnce(Rc<ThreadState>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Output>,
        Output: Send + Sync + 'static,
    {
        let (return_channel_tx, return_channel_rx) = tokio::sync::oneshot::channel();

        let request = CallRequest {
            function: Box::new(move |state| {
                Box::pin(async move {
                    let output = function(state).await;
                    // Sending only fails when the caller stopped waiting for the result;
                    // nothing can be done about that here besides logging it.
                    if return_channel_tx.send(output).is_err() {
                        tracing::warn!(
                            "ThreadBoundRunner failed to send result back to the caller"
                        );
                    }
                })
            }),
        };

        // The receiver lives in a spawned task and can be dropped while handles still exist
        // (for example when its executor is torn down), so report that as an error instead
        // of panicking.
        self.call_channel_tx.send(request).await.map_err(|_| {
            CallError("the call channel is closed, no task is receiving calls".to_owned())
        })?;

        return_channel_rx
            .await
            .map_err(|e| CallError(e.to_string()))
    }
}
#[cfg(test)]
mod test {
use super::*;
/// Drives `test` to completion and returns its output.
///
/// On non-wasm targets the future is run inside a fresh tokio `LocalSet`; on `wasm32` it is
/// awaited directly.
async fn run_test<F: std::future::Future>(test: F) -> F::Output {
    // Native: run the test body through `LocalSet::run_until`.
    #[cfg(not(target_arch = "wasm32"))]
    {
        let local_tasks = tokio::task::LocalSet::new();
        local_tasks.run_until(test).await
    }
    // wasm32: no `LocalSet` is involved.
    #[cfg(target_arch = "wasm32")]
    {
        test.await
    }
}
/// Runs `test` to completion and discards its output.
///
/// On non-wasm targets the future is handed to `tokio::spawn` and the resulting join handle
/// is awaited, panicking with "Thread panicked" if joining fails. On `wasm32` the future is
/// simply awaited in place.
async fn run_in_another_thread<F>(test: F)
where
    F: std::future::Future + Send + 'static,
    F::Output: Send,
{
    #[cfg(not(target_arch = "wasm32"))]
    {
        let join_handle = tokio::spawn(test);
        join_handle.await.expect("Thread panicked");
    }
    #[cfg(target_arch = "wasm32")]
    {
        test.await;
    }
}
/// Test fixture used as the runner's state.
///
/// The raw-pointer `PhantomData` marker makes the type neither `Send` nor `Sync`.
#[derive(Default)]
struct State {
    _un_send_marker: std::marker::PhantomData<*const ()>,
}

impl State {
    /// Returns the sum of the two tuple elements.
    pub fn add(&self, input: (i32, i32)) -> i32 {
        let (lhs, rhs) = input;
        lhs + rhs
    }

    /// Async counterpart of [`State::add`]; it never actually suspends.
    #[allow(clippy::unused_async)]
    pub async fn async_add(&self, input: (i32, i32)) -> i32 {
        let (lhs, rhs) = input;
        lhs + rhs
    }
}
/// A synchronous computation on the state is executed and its value is returned.
#[tokio::test]
async fn calls_function_and_returns_value() {
    let scenario = async {
        let runner = ThreadBoundRunner::new(State::default());

        let call = runner.run_in_thread(|state| async move {
            let input = (1, 2);
            state.add(input)
        });
        let sum = call.await.expect("Calling function failed");

        assert_eq!(sum, 3);
    };
    run_test(scenario).await;
}
/// An async method on the state can be awaited inside the call and its value is returned.
#[tokio::test]
async fn calls_async_function_and_returns_value() {
    let scenario = async {
        let runner = ThreadBoundRunner::new(State::default());

        let call = runner.run_in_thread(|state| async move {
            let input = (1, 2);
            state.async_add(input).await
        });
        let sum = call.await.expect("Calling function failed");

        assert_eq!(sum, 3);
    };
    run_test(scenario).await;
}
/// A panicking call yields an error, and the same runner still serves the next call.
#[tokio::test]
async fn can_continue_running_if_a_call_panics() {
    let scenario = async {
        let runner = ThreadBoundRunner::new(State::default());

        // First call: the submitted future panics, so the caller must receive an error.
        let panicking_call = runner.run_in_thread::<_, _, ()>(|state| async move {
            panic!("This is a test panic");
        });
        panicking_call
            .await
            .expect_err("Calling function should have panicked");

        // Second call on the same runner must still succeed.
        let follow_up = runner.run_in_thread(|state| async move {
            let input = (1, 2);
            state.async_add(input).await
        });
        let sum = follow_up.await.expect("Calling function failed");

        assert_eq!(sum, 3);
    };
    run_test(scenario).await;
}
}