drying_paint/
sync.rs

1/* SPDX-License-Identifier: (Apache-2.0 OR MIT OR Zlib) */
2/* Copyright © 2021 Violet Leonard */
3
4use {
5    alloc::sync::{Arc, Weak},
6    core::{
7        cell::Cell,
8        fmt, mem, ptr,
9        sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
10    },
11};
12
13use crate::{trigger::WatchArg, WatchedMeta};
14
15const FLAG_COUNT: usize = usize::BITS as usize;
16
17pub(crate) struct SyncContext<'ctx, O: ?Sized> {
18    flag: Arc<AtomicUsize>,
19    watched: [WatchedMeta<'ctx, O>; FLAG_COUNT],
20    next_index: Cell<usize>,
21}
22
23impl<'ctx, O: ?Sized> SyncContext<'ctx, O> {
24    pub fn new() -> Self {
25        Self {
26            flag: Arc::default(),
27            watched: [0; FLAG_COUNT].map(|_| WatchedMeta::new()),
28            next_index: Cell::new(0),
29        }
30    }
31
32    pub fn check_for_updates(&self) {
33        let set_bits = self.flag.swap(0, Ordering::Acquire);
34        for i in 0..FLAG_COUNT {
35            if (set_bits & (1 << i)) != 0 {
36                self.watched[i].trigger_external();
37            }
38        }
39    }
40}
41
/// A write-once slot holding a `Weak<AtomicUsize>` that can be read
/// atomically.
///
/// The weak reference is kept as the raw pointer obtained from
/// `Weak::into_raw`; a null pointer means nothing has been stored yet.
struct FlagPole {
    ptr: AtomicPtr<AtomicUsize>,
}

impl Default for FlagPole {
    /// Creates an empty slot (null pointer).
    fn default() -> Self {
        Self {
            ptr: AtomicPtr::default(),
        }
    }
}

impl Drop for FlagPole {
    /// Releases the weak reference owned by the slot, if one was stored.
    fn drop(&mut self) {
        let raw: *mut AtomicUsize = *self.ptr.get_mut();
        if raw.is_null() {
            return;
        }
        // SAFETY: a non-null pointer is only ever installed by `set`, which
        // obtains it from `Weak::into_raw`, and the slot still owns that
        // one weak reference.
        drop(unsafe { Weak::from_raw(raw) });
    }
}

impl FlagPole {
    /// Stores `value` if the slot is still empty; otherwise `value` is
    /// simply dropped and the existing content is kept.
    fn set(&self, value: Weak<AtomicUsize>) {
        let candidate = Weak::into_raw(value) as *mut AtomicUsize;
        // Only the transition null -> candidate is allowed.
        let installed = self.ptr.compare_exchange(
            ptr::null_mut(),
            candidate,
            Ordering::Release,
            Ordering::Relaxed,
        );
        if installed.is_err() {
            // The slot was already occupied: take the reference back so
            // that its weak count is decremented.
            // SAFETY: `candidate` came from `Weak::into_raw` just above and
            // was not stored anywhere.
            drop(unsafe { Weak::from_raw(candidate) });
        }
    }

    /// Returns a new weak reference to the stored flag, or a dangling
    /// `Weak::new()` if nothing has been stored.
    fn get(&self) -> Weak<AtomicUsize> {
        let raw = self.ptr.load(Ordering::Acquire);
        if raw.is_null() {
            return Weak::new();
        }
        // SAFETY: a non-null pointer in the slot came from `Weak::into_raw`
        // in `set` and is only released in `drop`.
        let stored = unsafe { Weak::from_raw(raw) };
        let handed_out = Weak::clone(&stored);
        // The slot keeps its own reference: forget the reconstructed one so
        // the pointer held in the atomic stays valid.
        mem::forget(stored);
        handed_out
    }
}
101
/// State shared between a `SyncWatchedMeta` and the `SyncTrigger`s created
/// from it.
#[derive(Default)]
struct SharedMeta {
    /// Write-once slot for a weak handle to the flag word of the sync
    /// context; filled in by `SyncWatchedMeta::watched`.
    flag_pole: FlagPole,
    /// Bit mask that `SyncTrigger::trigger` ORs into the flag word.
    mask: AtomicUsize,
}
107
108/// SyncWatchedMeta is like WatchedMeta, however allows you to create
109/// a trigger which may be sent to other threads.
110///
111/// When this trigger is invoked, watch functions in the single-threaded watch
112/// context will be re-run.
113pub struct SyncWatchedMeta {
114    data: Arc<SharedMeta>,
115    index: Cell<usize>,
116}
117
118impl Default for SyncWatchedMeta {
119    fn default() -> Self {
120        Self {
121            data: Arc::default(),
122            index: Cell::new(usize::MAX),
123        }
124    }
125}
126
127impl fmt::Debug for SyncWatchedMeta {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        write!(f, "(SyncWatchedMeta)")
130    }
131}
132
133impl SyncWatchedMeta {
134    /// Create a new AtomicWatchedMeta
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    /// When run in a function designed to watch a value, will bind so that
140    /// function will be re-run when a trigger associated with this
141    /// AtomicWatchedMeta is invoked.
142    pub fn watched<O: ?Sized>(&self, ctx: WatchArg<'_, '_, O>) {
143        if let Some(sctx) = ctx.frame_info.sync_context.upgrade() {
144            if self.index.get() == usize::MAX {
145                let index = sctx.next_index.get();
146                sctx.next_index.set(index + 1 % FLAG_COUNT);
147                let mask = 1 << index;
148                let weak_flag = Arc::downgrade(&sctx.flag);
149                self.data.mask.store(mask, Ordering::Relaxed);
150                self.data.flag_pole.set(weak_flag);
151                self.index.set(index);
152            }
153            sctx.watched[self.index.get()].watched(ctx);
154        }
155    }
156
157    /// Create a trigger for this AtomicWatchedMeta which may be sent to
158    /// another thread.
159    pub fn create_trigger(&self) -> SyncTrigger {
160        SyncTrigger {
161            data: Arc::downgrade(&self.data),
162        }
163    }
164}
165
166/// A type which can be used from another thread to trigger watch functions
167/// watching an AtomicWatchedMeta.
168#[derive(Clone)]
169pub struct SyncTrigger {
170    data: Weak<SharedMeta>,
171}
172
173impl fmt::Debug for SyncTrigger {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        write!(f, "(SyncTrigger)")
176    }
177}
178
179impl SyncTrigger {
180    /// Create an SyncTrigger which is not assocaited with any
181    /// SyncWatchedMeta.  Invoking the trigger returned from this function
182    /// will do nothing.  This may be useful e.g. as a placeholder value.
183    pub fn new_inert() -> Self {
184        Self { data: Weak::new() }
185    }
186
187    pub fn trigger(&self) {
188        if let Some(data) = self.data.upgrade() {
189            if let Some(flag) = data.flag_pole.get().upgrade() {
190                let mask = data.mask.load(Ordering::Relaxed);
191                flag.fetch_or(mask, Ordering::Release);
192            }
193        }
194    }
195}
196
197pub fn watched_channel<S, R>(
198    pair: (S, R),
199) -> (WatchedSender<S>, WatchedReceiver<R>) {
200    let (sender, receiver) = pair;
201    let meta = SyncWatchedMeta::new();
202    let trigger = meta.create_trigger();
203    (
204        WatchedSender { sender, trigger },
205        WatchedReceiver { receiver, meta },
206    )
207}
208
209/// The sender half of a watched channel.
210#[derive(Clone, Debug)]
211pub struct WatchedSender<S: ?Sized> {
212    trigger: SyncTrigger,
213    sender: S,
214}
215
216impl<S: ?Sized> Drop for WatchedSender<S> {
217    fn drop(&mut self) {
218        self.trigger.trigger();
219    }
220}
221
222impl<S: ?Sized> WatchedSender<S> {
223    pub fn sender(&self) -> SendGuard<'_, S> {
224        SendGuard { origin: self }
225    }
226
227    pub fn trigger_receiver(&self) {
228        self.trigger.trigger();
229    }
230}
231
232pub struct SendGuard<'a, S: ?Sized> {
233    origin: &'a WatchedSender<S>,
234}
235
236impl<'a, S: ?Sized> core::ops::Deref for SendGuard<'a, S> {
237    type Target = S;
238    fn deref(&self) -> &S {
239        &self.origin.sender
240    }
241}
242
243impl<'a, S: ?Sized> Drop for SendGuard<'a, S> {
244    fn drop(&mut self) {
245        self.origin.trigger.trigger();
246    }
247}
248
249#[derive(Debug)]
250pub struct WatchedReceiver<R: ?Sized> {
251    meta: SyncWatchedMeta,
252    receiver: R,
253}
254
255impl<R: ?Sized> WatchedReceiver<R> {
256    pub fn get<O: ?Sized>(&self, ctx: WatchArg<'_, '_, O>) -> &R {
257        self.meta.watched(ctx);
258        &self.receiver
259    }
260
261    pub fn get_mut<O: ?Sized>(&mut self, ctx: WatchArg<'_, '_, O>) -> &mut R {
262        self.meta.watched(ctx);
263        &mut self.receiver
264    }
265}