use futures_channel::oneshot;
use futures_util::task::noop_waker;
use scoped_tls::scoped_thread_local;
use std::{
cell::{Cell, RefCell},
future::Future,
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use web_sys::{
js_sys::Function,
wasm_bindgen::{closure::Closure, JsCast},
IdbRequest, IdbTransaction,
};
/// The user-provided transaction body, once boxed and pinned on the heap.
type TransactionFuture = dyn 'static + Future<Output = Result<(), ()>>;

/// Shared state of a running transaction.
///
/// A clone of it is captured by every request callback registered through
/// `add_request`, so that a settling request can re-poll the transaction body.
#[derive(Clone)]
struct State {
    /// The IndexedDB transaction the body runs in; aborted on failure.
    transaction: IdbTransaction,
    /// Number of requests registered via `add_request` whose success/error
    /// callback has not fired yet.
    inflight_requests: Rc<Cell<usize>>,
    /// The transaction body. Pinned in a `Box` so it can be polled without
    /// any `unsafe` pinning.
    future: Rc<RefCell<Pin<Box<TransactionFuture>>>>,
    /// Set once the body has returned `Ready`, panicked, or been given up on.
    /// A finished future must never be polled again.
    finished: Rc<Cell<bool>>,
}

// The state of the transaction whose body is currently being polled; read by
// `add_request` to register new requests against it.
scoped_thread_local!(static CURRENT: State);
// Set to `true` when the body returned `Pending` while no request was in
// flight (it awaited something other than an IndexedDB request).
// NOTE(review): presumably read elsewhere in the crate to report a clearer
// error — the reader is not visible here.
thread_local!(pub(crate) static POLLED_FORBIDDEN_THING: Cell<bool> = Cell::new(false));

/// Polls the transaction body once, with `state` installed as `CURRENT` so
/// that `add_request` can register new requests against it.
///
/// The transaction is aborted if the body panics (the panic is then
/// re-raised), returns `Err(())`, or returns `Pending` while no request is in
/// flight. In that last case nothing could ever wake the body again, so this
/// also sets `POLLED_FORBIDDEN_THING` and panics.
///
/// Does nothing if the body has already finished. Callbacks of requests that
/// are still in flight (e.g. requests failing because we called `abort()`)
/// would otherwise re-poll a completed future, which `Future::poll` forbids
/// (an `async` block panics when resumed after completion).
fn poll_it(state: &State) {
    if state.finished.get() {
        return;
    }
    CURRENT.set(state, || {
        let polled = {
            let mut future = state.future.borrow_mut();
            // A no-op waker is enough: the body is only ever resumed by
            // request callbacks calling `poll_it`, never through its waker.
            let waker = noop_waker();
            let mut cx = Context::from_waker(&waker);
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                future.as_mut().poll(&mut cx)
            }))
        };
        let polled = match polled {
            Ok(polled) => polled,
            Err(payload) => {
                state.finished.set(true);
                let _ = state.transaction.abort();
                std::panic::resume_unwind(payload);
            }
        };
        match polled {
            Poll::Pending if state.inflight_requests.get() == 0 => {
                state.finished.set(true);
                let _ = state.transaction.abort();
                POLLED_FORBIDDEN_THING.set(true);
                panic!("Transaction blocked without any request under way");
            }
            Poll::Pending => {
                // A request callback will re-poll the body once it settles.
            }
            Poll::Ready(Ok(())) => {
                // No explicit commit: relies on IndexedDB auto-committing the
                // transaction once it has no pending requests.
                state.finished.set(true);
            }
            Poll::Ready(Err(())) => {
                state.finished.set(true);
                let _ = state.transaction.abort();
            }
        }
    });
}

/// Sends `value` to the request future waiting on `tx`.
///
/// If the receiving side was dropped, the transaction is aborted.
/// Presumably this is because the body can no longer react to this
/// request's outcome.
fn send_or_abort<T>(transaction: &IdbTransaction, tx: oneshot::Sender<T>, value: T) {
    if tx.send(value).is_err() {
        let _ = transaction.abort();
    }
}

/// Starts running `transaction_contents` as the body of `transaction`.
///
/// The body is polled once right away. After that it is only re-polled from
/// the success/error callbacks of requests registered through `add_request`.
/// The transaction is aborted if the body returns `Err(())` or panics.
pub fn run<Fut>(transaction: IdbTransaction, transaction_contents: Fut)
where
    Fut: 'static + Future<Output = Result<(), ()>>,
{
    let future: Pin<Box<TransactionFuture>> = Box::pin(transaction_contents);
    let state = State {
        transaction,
        inflight_requests: Rc::new(Cell::new(0)),
        future: Rc::new(RefCell::new(future)),
        finished: Rc::new(Cell::new(false)),
    };
    poll_it(&state);
}
/// Registers `req` as an in-flight request of the transaction currently being
/// polled.
///
/// When the request settles:
/// - its event is sent on `success_tx` or `error_tx`, respectively;
/// - the in-flight counter is decremented;
/// - the transaction body is polled again.
///
/// Error events have their default action prevented. Per IndexedDB, that
/// default would abort the transaction, so with it prevented the body decides
/// (by returning `Err(())` or not).
///
/// The returned value owns both JS callbacks. Dropping a wasm-bindgen
/// `Closure` invalidates the JS function handed to `req`, so the caller
/// presumably must keep the value alive until the request settles.
///
/// # Panics
///
/// Panics if called while no transaction body is being polled (i.e. `CURRENT`
/// is not set).
pub fn add_request(
    req: IdbRequest,
    success_tx: oneshot::Sender<web_sys::Event>,
    error_tx: oneshot::Sender<web_sys::Event>,
) -> impl Sized {
    CURRENT.with(move |state| {
        state
            .inflight_requests
            .set(state.inflight_requests.get() + 1);

        let on_success = Closure::once({
            let state = state.clone();
            move |evt: web_sys::Event| {
                state
                    .inflight_requests
                    .set(state.inflight_requests.get() - 1);
                send_or_abort(&state.transaction, success_tx, evt);
                poll_it(&state);
            }
        });

        let on_error = Closure::once({
            let state = state.clone();
            move |evt: web_sys::Event| {
                // Done first, so that it happens even if re-polling unwinds;
                // this also lets the event be moved instead of cloned.
                evt.prevent_default();
                state
                    .inflight_requests
                    .set(state.inflight_requests.get() - 1);
                send_or_abort(&state.transaction, error_tx, evt);
                poll_it(&state);
            }
        });

        // A `Closure`'s JS value is always a function, so no checked cast.
        req.set_onsuccess(Some(on_success.as_ref().unchecked_ref::<Function>()));
        req.set_onerror(Some(on_error.as_ref().unchecked_ref::<Function>()));
        (on_success, on_error)
    })
}