use crate::application;
use send_cells::UnsafeSendCell;
use send_cells::unsafe_sync_cell::UnsafeSyncCell;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex, MutexGuard};
/// State shared by every clone of a [`MainThreadCell`] (held behind one `Arc`).
#[derive(Debug)]
struct Shared<T: 'static> {
/// The wrapped value. It is `Some` from construction until `Drop` takes it and hands it to
/// the main thread to be destroyed there.
inner: Option<UnsafeSendCell<UnsafeSyncCell<T>>>,
/// Borrow mutex held by `lock`, `assume`, `with` and `with_async` while they expose a
/// reference to the value. It guards no data of its own (`()`).
mutex: Mutex<()>,
}
impl<T> Drop for Shared<T> {
    /// Forwards destruction of the wrapped value to the main thread.
    ///
    /// `MainThreadCell` handles are `Send`, so the last one (and therefore this `Shared`) may be
    /// released on any thread. The value is moved into a job for
    /// `application::submit_to_main_thread` and dropped there.
    fn drop(&mut self) {
        let Some(cell) = self.inner.take() else {
            return;
        };
        let label = format!("MainThreadCell::drop({})", std::any::type_name::<T>());
        application::submit_to_main_thread(label, move || drop(cell));
    }
}
/// RAII guard returned by [`MainThreadCell::lock`], giving mutable access to the wrapped value.
///
/// The cell stays borrowed until this guard is dropped.
pub struct MainThreadGuard<'a, T: 'static> {
/// Keeps the cell's borrow mutex locked for as long as `value` is reachable. Because
/// `MutexGuard` is `!Send`, it also pins this guard to the thread that created it.
_guard: MutexGuard<'a, ()>,
/// Exclusive reference to the wrapped value, valid while `_guard` is held.
value: &'a mut T,
}
impl<T> AsRef<T> for MainThreadGuard<'_, T> {
    /// Borrows the guarded value. Delegates to the guard's `Deref` impl via deref coercion.
    fn as_ref(&self) -> &T {
        self
    }
}
impl<T> AsMut<T> for MainThreadGuard<'_, T> {
    /// Mutably borrows the guarded value. Delegates to the guard's `DerefMut` impl via deref
    /// coercion.
    fn as_mut(&mut self) -> &mut T {
        self
    }
}
impl<T> Deref for MainThreadGuard<'_, T> {
    type Target = T;

    /// Shared access to the value this guard has locked.
    fn deref(&self) -> &T {
        // `&mut T` -> `&T` at a coercion site is an implicit shared reborrow, not a move.
        self.value
    }
}
impl<T> DerefMut for MainThreadGuard<'_, T> {
    /// Exclusive access to the value this guard has locked.
    fn deref_mut(&mut self) -> &mut T {
        // Implicit reborrow of the stored `&mut T` for the duration of `&mut self`.
        self.value
    }
}
impl<T: Debug> Debug for MainThreadGuard<'_, T> {
    /// Formats as `MainThreadGuard { value: .. }`, using `T`'s own `Debug` for the value.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let value: &T = self;
        f.debug_struct("MainThreadGuard")
            .field("value", value)
            .finish()
    }
}
/// A handle to a value that may only be accessed on the application's main thread.
///
/// The handle itself is `Send` and cheap to clone: clones share one `Arc` allocation.
/// - [`MainThreadCell::lock`] and [`MainThreadCell::assume`] access the value directly and must
///   be called on the main thread.
/// - [`MainThreadCell::with`] and [`MainThreadCell::with_async`] may be awaited from any thread;
///   they run the caller's closure on the main thread.
///
/// When the last handle is dropped, the value is dropped on the main thread.
pub struct MainThreadCell<T: 'static> {
/// Shared state. `new` (which `Default` and `From` call) always stores `Some`, and `Clone`
/// copies it.
shared: Option<Arc<Shared<T>>>,
}
impl<T> PartialEq for MainThreadCell<T> {
    /// Identity comparison. Two handles are equal exactly when they share the same allocation,
    /// i.e. one was cloned (directly or transitively) from the other. The wrapped value is never
    /// read, so this works on any thread and needs no `T: PartialEq`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (
            self.shared.as_ref().unwrap(),
            other.shared.as_ref().unwrap(),
        );
        Arc::ptr_eq(lhs, rhs)
    }
}
impl<T> Clone for MainThreadCell<T> {
fn clone(&self) -> Self {
MainThreadCell {
shared: self.shared.clone(),
}
}
}
impl<T> MainThreadCell<T> {
#[inline]
pub fn new(t: T) -> Self {
let cell = unsafe { UnsafeSendCell::new_unchecked(UnsafeSyncCell::new(t)) };
MainThreadCell {
shared: Some(Arc::new(Shared {
inner: Some(cell),
mutex: Mutex::new(()),
})),
}
}
#[inline]
fn verify_main_thread() {
assert!(
application::is_main_thread(),
"MainThreadCell accessed from non-main thread"
);
}
pub fn lock(&self) -> MainThreadGuard<'_, T> {
Self::verify_main_thread();
let guard = self.shared.as_ref().unwrap().mutex.lock().unwrap();
let value = unsafe {
let inner = self.shared.as_ref().unwrap().inner.as_ref().unwrap();
inner.get().get_mut_unchecked()
};
MainThreadGuard {
_guard: guard,
value,
}
}
pub fn assume<C, R>(&self, c: C) -> R
where
C: FnOnce(&T) -> R,
{
Self::verify_main_thread();
let guard = self.shared.as_ref().unwrap().mutex.lock().unwrap();
let r = c(unsafe {
self.shared
.as_ref()
.unwrap()
.inner
.as_ref()
.unwrap()
.get()
.get()
});
drop(guard);
r
}
pub async fn with<C, R>(&self, c: C) -> R
where
C: FnOnce(&T) -> R + Send + 'static,
R: Send + 'static,
T: 'static,
{
let shared = self.shared.clone();
let main_thread_cell = format!("MainThreadCell({})", std::any::type_name::<T>());
application::on_main_thread(main_thread_cell, move || {
Self::verify_main_thread();
let guard = shared.as_ref().unwrap().mutex.lock().unwrap();
let r = c(unsafe { shared.as_ref().unwrap().inner.as_ref().unwrap().get().get() });
drop(guard);
r
})
.await
}
pub async fn with_async<C, R, F>(&self, c: C) -> R
where
C: FnOnce(&T) -> F + Send + 'static,
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
T: 'static,
{
let shared = self.shared.clone();
let main_thread_cell = format!("MainThreadCell({})", std::any::type_name::<T>());
let future = application::on_main_thread(main_thread_cell, move || {
Self::verify_main_thread();
let guard = shared.as_ref().unwrap().mutex.lock().unwrap();
let future = c(unsafe { shared.as_ref().unwrap().inner.as_ref().unwrap().get().get() });
drop(guard);
future
})
.await;
future.await
}
pub async fn new_on_main_thread<C, F>(c: C) -> MainThreadCell<T>
where
C: FnOnce() -> F + Send + 'static,
F: Future<Output = T> + Send + 'static,
{
logwise::info_sync!("MainThreadCell::new_on_main_thread() started");
let new_on_main_thread = format!(
"MainThreadCell::new_on_main_thread({})",
std::any::type_name::<T>()
);
let value = application::on_main_thread(new_on_main_thread, || async move {
logwise::info_sync!("Inside main thread closure");
let f = c();
logwise::info_sync!("Calling provided closure f()...");
let r = f.await;
logwise::info_sync!("Closure completed, creating MainThreadCell...");
MainThreadCell::new(r)
})
.await
.await;
logwise::info_sync!("Main thread execution completed, returning value");
value
}
}
// SAFETY: sending a handle to another thread moves only the `Arc` to the shared state.
// The wrapped `T` is reached in only two ways:
// (a) after `verify_main_thread` in `lock` / `assume`;
// (b) inside closures run through `application::on_main_thread` in `with` / `with_async`,
//     which also call `verify_main_thread`.
// Destruction is forwarded to the main thread by `Shared`'s `Drop` impl. `PartialEq`, `Clone`
// and `Debug` never read `T`.
unsafe impl<T> Send for MainThreadCell<T> {}
impl<T: Debug> Debug for MainThreadCell<T> {
    /// Prints only the type name.
    ///
    /// The value is deliberately not shown: formatting may happen on any thread, and reading
    /// the value requires the main thread.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("MainThreadCell")
    }
}
impl<T: Default> Default for MainThreadCell<T> {
fn default() -> Self {
MainThreadCell::new(Default::default())
}
}
impl<T> From<T> for MainThreadCell<T> {
fn from(value: T) -> Self {
MainThreadCell::new(value)
}
}
// Tests for the thread-agnostic parts of `MainThreadCell`: construction, `Debug`, and moving a
// handle to another thread. Every cell is passed to `std::mem::forget` so that `Shared`'s
// `Drop` never runs, because it would hand the value to `application::submit_to_main_thread`.
// NOTE(review): presumably no main-thread loop is running under the test harness — confirm.
#[cfg(test)]
mod tests {
use super::*;
// `thread` is std threads natively and `wasm_safe_thread` on wasm32.
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(target_arch = "wasm32")]
use wasm_safe_thread as thread;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_cell_construction() {
// Exercises all three constructors: `new`, `From`/`into`, and `Default`.
let cell = MainThreadCell::new(42);
let cell_from: MainThreadCell<i32> = 42.into();
let cell_default: MainThreadCell<i32> = Default::default();
std::mem::forget(cell);
std::mem::forget(cell_from);
std::mem::forget(cell_default);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_debug_impl() {
// `Debug` must work without touching the value (it prints only the type name).
let cell = MainThreadCell::new(42);
let debug_str = format!("{:?}", cell);
assert!(debug_str.contains("MainThreadCell"));
std::mem::forget(cell);
}
#[test_executors::async_test]
async fn test_send_across_threads() {
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
// Moves a handle into a spawned thread (relies on `unsafe impl Send`). The test then waits
// until that thread signals through the continuation pair.
let cell = MainThreadCell::new(42);
let (c, f) = r#continue::continuation();
thread::spawn(move || {
let held_cell = cell;
c.send(());
std::mem::forget(held_cell);
});
f.await;
}
}