use {
alloc::sync::{Arc, Weak},
core::{
cell::Cell,
fmt, mem, ptr,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
},
};
use super::{WatchArg, WatchedMeta};
/// Number of notification slots: one per bit of a `usize` flag word.
const FLAG_COUNT: usize = usize::BITS as usize;
/// Per-context hub that turns bits raised in a shared flag word into
/// external triggers on the matching `WatchedMeta` slot.
///
/// Bit `i` of `flag` corresponds to `watched[i]`.
pub(crate) struct SyncContext<'ctx, O: ?Sized> {
    /// Flag word; drained (swapped to zero) by `check_for_updates`.
    flag: Arc<AtomicUsize>,
    /// One `WatchedMeta` per bit of the flag word.
    watched: [WatchedMeta<'ctx, O>; FLAG_COUNT],
    /// Slot index handed out to the next `SyncWatchedMeta` that registers.
    next_index: Cell<usize>,
}

impl<'ctx, O: ?Sized> SyncContext<'ctx, O> {
    /// Creates a context with a zeroed flag word, `FLAG_COUNT` fresh
    /// `WatchedMeta` slots and the slot counter at zero.
    pub fn new() -> Self {
        let flag = Arc::new(AtomicUsize::new(0));
        let watched = [(); FLAG_COUNT].map(|()| WatchedMeta::new());
        Self {
            flag,
            watched,
            next_index: Cell::new(0),
        }
    }

    /// Atomically takes every bit currently set in the flag word (leaving
    /// it zero) and calls `trigger_external` on the slot of each set bit,
    /// in ascending bit order.
    pub fn check_for_updates(&self) {
        let pending = self.flag.swap(0, Ordering::Acquire);
        let raised = self
            .watched
            .iter()
            .enumerate()
            .filter(|&(bit, _)| pending & (1usize << bit) != 0);
        for (_, meta) in raised {
            meta.trigger_external();
        }
    }
}
/// Set-once slot that stores a `Weak<AtomicUsize>` as a raw pointer inside
/// an `AtomicPtr`.
///
/// A null pointer means "nothing stored yet". A non-null pointer was
/// produced by `Weak::into_raw` in [`FlagPole::set`], and the slot owns
/// exactly one weak reference through it until the slot is dropped.
struct FlagPole {
    ptr: AtomicPtr<AtomicUsize>,
}

impl Drop for FlagPole {
    /// Releases the weak reference held by the slot, if one was stored.
    fn drop(&mut self) {
        let flag_ptr: *mut AtomicUsize = *self.ptr.get_mut();
        if !flag_ptr.is_null() {
            // SAFETY: a non-null pointer in the slot came from
            // `Weak::into_raw` in `set`, and the slot still owns that weak
            // reference. `&mut self` rules out concurrent readers.
            drop(unsafe { Weak::from_raw(flag_ptr) });
        }
    }
}

impl Default for FlagPole {
    /// Creates an empty slot (null pointer).
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::new(ptr::null_mut()),
        }
    }
}

impl FlagPole {
    /// Stores `value` in the slot if the slot is still empty.
    ///
    /// Only the first call takes effect; when the slot is already occupied
    /// the passed-in `value` is dropped and the stored one is kept.
    fn set(&self, value: Weak<AtomicUsize>) {
        let flag_ptr = value.into_raw() as *mut AtomicUsize;
        let installed = self
            .ptr
            .compare_exchange(
                ptr::null_mut(),
                flag_ptr,
                Ordering::Release,
                Ordering::Relaxed,
            )
            .is_ok();
        if !installed {
            // SAFETY: `flag_ptr` was produced by `into_raw` just above and
            // was not published, so this is the only owner of that weak
            // reference; rebuilding the `Weak` releases it.
            drop(unsafe { Weak::from_raw(flag_ptr) });
        }
    }

    /// Returns a new `Weak` to the stored flag, or an empty `Weak` (one that
    /// never upgrades) when nothing has been stored.
    fn get(&self) -> Weak<AtomicUsize> {
        let flag_ptr = self.ptr.load(Ordering::Acquire);
        if flag_ptr.is_null() {
            Weak::new()
        } else {
            // SAFETY: a non-null pointer in the slot came from
            // `Weak::into_raw` in `set` and the slot keeps its weak
            // reference alive for as long as `self` exists.
            let current = unsafe { Weak::from_raw(flag_ptr) };
            // `current` has just taken over the slot's weak reference, so
            // leak one clone to leave the slot's own reference accounted
            // for; the caller then owns `current` independently.
            mem::forget(Weak::clone(&current));
            current
        }
    }
}
/// State shared between a `SyncWatchedMeta` and the `SyncTrigger`s created
/// from it.
#[derive(Default)]
struct SharedMeta {
    /// Set-once slot holding a weak handle to the flag word; filled in by
    /// `SyncWatchedMeta::watched` and read by `SyncTrigger::trigger`.
    flag_pole: FlagPole,
    /// Bit mask that `SyncTrigger::trigger` ORs into the flag word; stored
    /// by `SyncWatchedMeta::watched` before the flag pole is set.
    mask: AtomicUsize,
}
/// Watch metadata whose triggers ([`SyncTrigger`]) work by raising a bit in
/// the sync context's flag word.
pub struct SyncWatchedMeta {
    /// State shared with every `SyncTrigger` created from this meta.
    data: Arc<SharedMeta>,
    /// Slot index inside the sync context, or `Self::UNASSIGNED` until the
    /// first successful registration in [`SyncWatchedMeta::watched`].
    index: Cell<usize>,
}

impl Default for SyncWatchedMeta {
    /// Creates a meta with fresh shared state and no slot assigned yet.
    fn default() -> Self {
        Self {
            data: Arc::default(),
            index: Cell::new(Self::UNASSIGNED),
        }
    }
}

impl fmt::Debug for SyncWatchedMeta {
    /// Writes the fixed placeholder `(SyncWatchedMeta)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("(SyncWatchedMeta)")
    }
}

impl SyncWatchedMeta {
    /// Sentinel stored in `index` while no slot has been assigned.
    const UNASSIGNED: usize = usize::MAX;

    /// Creates a meta with no slot assigned; same as `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers the current watch with this meta.
    ///
    /// Does nothing when the frame's sync context can no longer be
    /// upgraded. On the first successful call a slot index is claimed from
    /// the context, the matching bit mask and a weak handle to the context's
    /// flag word are published for triggers, and the index is remembered.
    /// Every successful call then forwards `ctx` to the `WatchedMeta` in
    /// that slot.
    ///
    /// Slot indices wrap around after `FLAG_COUNT` registrations, so several
    /// metas may end up sharing one slot (and therefore one flag bit).
    pub fn watched<'ctx, O: ?Sized>(&self, ctx: WatchArg<'_, 'ctx, O>) {
        if let Some(sctx) = ctx.frame_info.sync_context.upgrade() {
            if self.index.get() == Self::UNASSIGNED {
                let index = sctx.next_index.get();
                // The parentheses matter: `%` binds tighter than `+`, so
                // without them the counter never wrapped and eventually
                // produced an out-of-range shift amount and array index.
                sctx.next_index.set((index + 1) % FLAG_COUNT);
                let mask: usize = 1 << index;
                let weak_flag = Arc::downgrade(&sctx.flag);
                // Store the mask before publishing the flag handle: the
                // release in `FlagPole::set` pairs with the acquire in
                // `FlagPole::get`, making the mask visible to triggers.
                self.data.mask.store(mask, Ordering::Relaxed);
                self.data.flag_pole.set(weak_flag);
                self.index.set(index);
            }
            sctx.watched[self.index.get()].watched(ctx);
        }
    }

    /// Creates a trigger holding a weak reference to this meta's shared
    /// state.
    pub fn create_trigger(&self) -> SyncTrigger {
        SyncTrigger {
            data: Arc::downgrade(&self.data),
        }
    }
}
/// Cloneable handle that raises the flag bit of the `SyncWatchedMeta` it was
/// created from.
///
/// Holds only a weak reference to the shared state, so it does not keep the
/// meta alive.
#[derive(Clone)]
pub struct SyncTrigger {
    data: Weak<SharedMeta>,
}

impl fmt::Debug for SyncTrigger {
    /// Writes the fixed placeholder `(SyncTrigger)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("(SyncTrigger)")
    }
}

impl SyncTrigger {
    /// Creates a trigger that is connected to nothing; calling
    /// [`SyncTrigger::trigger`] on it has no effect.
    pub fn new_inert() -> Self {
        let data = Weak::new();
        Self { data }
    }

    /// ORs this trigger's mask into the flag word.
    ///
    /// Returns without doing anything when either the shared state or the
    /// flag word is no longer alive (or no flag has been stored yet).
    pub fn trigger(&self) {
        let shared = match self.data.upgrade() {
            Some(shared) => shared,
            None => return,
        };
        let flag = match shared.flag_pole.get().upgrade() {
            Some(flag) => flag,
            None => return,
        };
        let mask = shared.mask.load(Ordering::Relaxed);
        flag.fetch_or(mask, Ordering::Release);
    }
}
/// Wraps the two halves of a channel-like pair so that the receiving half
/// can be watched and the sending half triggers it.
///
/// `pair` is `(sender, receiver)`. The returned `WatchedSender` carries a
/// trigger created from the `SyncWatchedMeta` stored in the returned
/// `WatchedReceiver`.
pub fn watched_channel<S, R>(
    pair: (S, R),
) -> (WatchedSender<S>, WatchedReceiver<R>) {
    let meta = SyncWatchedMeta::new();
    let watched_sender = WatchedSender {
        sender: pair.0,
        trigger: meta.create_trigger(),
    };
    let watched_receiver = WatchedReceiver {
        receiver: pair.1,
        meta,
    };
    (watched_sender, watched_receiver)
}
#[derive(Clone, Debug)]
pub struct WatchedSender<S: ?Sized> {
trigger: SyncTrigger,
sender: S,
}
impl<S: ?Sized> Drop for WatchedSender<S> {
fn drop(&mut self) {
self.trigger.trigger();
}
}
impl<S: ?Sized> WatchedSender<S> {
pub fn sender(&self) -> SendGuard<'_, S> {
SendGuard { origin: self }
}
pub fn trigger_receiver(&self) {
self.trigger.trigger();
}
}
/// Borrow of the sender inside a `WatchedSender`; fires the sender's
/// trigger when dropped.
pub struct SendGuard<'a, S: ?Sized> {
    origin: &'a WatchedSender<S>,
}

impl<S: ?Sized> core::ops::Deref for SendGuard<'_, S> {
    type Target = S;

    /// Gives shared access to the wrapped sender.
    fn deref(&self) -> &Self::Target {
        let inner: &S = &self.origin.sender;
        inner
    }
}

impl<S: ?Sized> Drop for SendGuard<'_, S> {
    /// Fires the originating sender's trigger once the borrow ends.
    fn drop(&mut self) {
        self.origin.trigger_receiver();
    }
}
/// Receiving half produced by `watched_channel`: a receiver of type `R`
/// together with the `SyncWatchedMeta` its sender triggers.
#[derive(Debug)]
pub struct WatchedReceiver<R: ?Sized> {
    meta: SyncWatchedMeta,
    receiver: R,
}

impl<R: ?Sized> WatchedReceiver<R> {
    /// Registers the current watch via `ctx`, then returns shared access to
    /// the wrapped receiver.
    pub fn get<'ctx, O: ?Sized>(&self, ctx: WatchArg<'_, 'ctx, O>) -> &R {
        let Self { meta, receiver } = self;
        meta.watched(ctx);
        receiver
    }

    /// Registers the current watch via `ctx`, then returns exclusive access
    /// to the wrapped receiver.
    pub fn get_mut<'ctx, O: ?Sized>(
        &mut self,
        ctx: WatchArg<'_, 'ctx, O>,
    ) -> &mut R {
        let Self { meta, receiver } = self;
        meta.watched(ctx);
        receiver
    }
}