bare_sync/
watch.rs

1//! A synchronization primitive for passing the latest value to **multiple** receivers.
2
use core::cell::{Cell, RefCell, UnsafeCell};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};

use embassy_sync::blocking_mutex::raw::RawMutex;
use embassy_sync::blocking_mutex::Mutex;
10
11/// The `Watch` is a single-slot signaling primitive that allows _multiple_ (`N`) receivers to get
12/// changes to the value. Unlike a [`Signal`](crate::signal::Signal), `Watch` supports multiple receivers,
13/// and unlike a [`PubSubChannel`](embassy_sync::pubsub::PubSubChannel), `Watch` immediately overwrites the previous
14/// value when a new one is sent, without waiting for all receivers to read the previous value.
15///
16/// This makes `Watch` particularly useful when a single task updates a value or "state", and multiple other tasks
17/// need to be notified about changes to this value asynchronously. Receivers may "lose" stale values, as they are
18/// always provided with the latest value.
19/// ```
20///
21/// use embedded_sync::watch::Watch;
22/// use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
23///
24/// static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
25///
26/// // Obtain receivers and sender
27/// let mut rcv0 = WATCH.receiver();
28/// let mut rcv1 = WATCH.receiver();
29/// let mut snd = WATCH.sender();
30///
31/// snd.send(10);
32///
33/// // Receive the new value (async or try)
34/// assert_eq!(rcv0.try_changed(), Some(10));
35/// assert_eq!(rcv1.try_changed(), Some(10));
36///
37/// // No update
38/// assert_eq!(rcv0.try_changed(), None);
39/// assert_eq!(rcv1.try_changed(), None);
40///
41/// snd.send(20);
42///
43/// // Using `get` marks the value as seen
44/// assert_eq!(rcv1.try_get(), Some(20));
45/// assert_eq!(rcv1.try_changed(), None);
46///
47/// snd.send(20);
48///
49/// assert_eq!(rcv1.try_get(), Some(20));
50/// assert_eq!(rcv1.try_get(), Some(20));
51///
52/// ```
53#[derive(Debug)]
54pub struct Watch<M: RawMutex, T: Clone> {
55    mutex: Mutex<M, WatchState<T>>,
56}
57
58#[derive(Debug)]
59struct WatchState<T: Clone> {
60    data: UnsafeCell<MaybeUninit<T>>,
61    current_id: Cell<u8>,
62}
63
64trait SealedWatchBehavior<T> {
65    /// Tries to retrieve the value of the `Watch` if it has changed, marking it as seen.
66    fn try_changed(&self, id: &mut u8) -> Option<T>;
67
68    /// Clears the value of the `Watch`.
69    fn clear(&self);
70
71    /// Sends a new value to the `Watch`.
72    fn send(&self, val: T);
73}
74
75/// A trait representing the 'inner' behavior of the `Watch`.
76#[allow(private_bounds)]
77pub trait WatchBehavior<T: Clone>: SealedWatchBehavior<T> {
78    /// Tries to get the value of the `Watch`, marking it as seen, if an id is given.
79    fn try_get(&self, id: Option<&mut u8>) -> Option<T>;
80
81    /// Checks if the `Watch` is been initialized with a value.
82    fn contains_value(&self) -> bool;
83}
84
85impl<M: RawMutex, T: Clone> SealedWatchBehavior<T> for Watch<M, T> {
86    fn try_changed(&self, id: &mut u8) -> Option<T> {
87        self.mutex.lock(|state| {
88            let current_id = state.current_id.get();
89            if current_id != *id {
90                *id = current_id;
91                let data = unsafe { state.data.get().read().assume_init() };
92                Some(data)
93            } else {
94                None
95            }
96        })
97    }
98
99    fn clear(&self) {
100        self.mutex.lock(|state| {
101            state.current_id.set(0);
102        })
103    }
104
105    fn send(&self, val: T) {
106        self.mutex.lock(|state| {
107            unsafe { state.data.get().write(MaybeUninit::new(val)) };
108            let mut new_id = state.current_id.get().wrapping_add(1);
109            if new_id == 0 {
110                new_id = 1;
111            }
112            state.current_id.set(new_id);
113        })
114    }
115}
116
117impl<M: RawMutex, T: Clone> WatchBehavior<T> for Watch<M, T> {
118    fn try_get(&self, id: Option<&mut u8>) -> Option<T> {
119        self.mutex.lock(|state| {
120            let current_id = state.current_id.get();
121            if let Some(id) = id {
122                *id = current_id;
123            }
124            if current_id == 0 {
125                None
126            } else {
127                let data = unsafe { state.data.get().read().assume_init() };
128                Some(data)
129            }
130        })
131    }
132
133    fn contains_value(&self) -> bool {
134        self.mutex.lock(|state| state.current_id.get() != 0)
135    }
136}
137
138impl<M: RawMutex, T: Clone> Watch<M, T> {
139    /// Create a new `Watch` channel for `N` receivers.
140    pub const fn new() -> Self {
141        Self {
142            mutex: Mutex::new(WatchState {
143                data: UnsafeCell::new(MaybeUninit::zeroed()),
144                current_id: Cell::new(0),
145            }),
146        }
147    }
148
149    /// Create a new `Watch` channel with default data.
150    pub const fn new_with(data: T) -> Self {
151        Self {
152            mutex: Mutex::new(WatchState {
153                data: UnsafeCell::new(MaybeUninit::new(data)),
154                current_id: Cell::new(0),
155            }),
156        }
157    }
158
159    /// Create a new [`Sender`] for the `Watch`.
160    pub fn sender(&self) -> Sender<'_, M, T> {
161        Sender(Snd::new(self))
162    }
163
164    /// Try to create a new [`Receiver`] for the `Watch`. If the
165    /// maximum number of receivers has been reached, `None` is returned.
166    pub fn receiver(&self) -> Receiver<'_, M, T> {
167        Receiver(Rcv::new(self))
168    }
169
170    /// Returns the message ID of the latest message sent to the `Watch`.
171    ///
172    /// This counter is monotonic, and is incremented every time a new message is sent.
173    pub fn get_msg_id(&self) -> u8 {
174        self.mutex.lock(|state| state.current_id.get())
175    }
176
177    /// Tries to get the value of the `Watch`.
178    pub fn try_get(&self) -> Option<T> {
179        WatchBehavior::try_get(self, None)
180    }
181}
182
183/// A receiver can `.await` a change in the `Watch` value.
184#[derive(Debug)]
185pub struct Snd<'a, T: Clone, W: WatchBehavior<T> + ?Sized> {
186    watch: &'a W,
187    _phantom: PhantomData<T>,
188}
189
190impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Clone for Snd<'a, T, W> {
191    fn clone(&self) -> Self {
192        Self {
193            watch: self.watch,
194            _phantom: PhantomData,
195        }
196    }
197}
198
199impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Snd<'a, T, W> {
200    /// Creates a new `Receiver` with a reference to the `Watch`.
201    fn new(watch: &'a W) -> Self {
202        Self {
203            watch,
204            _phantom: PhantomData,
205        }
206    }
207
208    /// Sends a new value to the `Watch`.
209    pub fn send(&self, val: T) {
210        self.watch.send(val)
211    }
212
213    /// Clears the value of the `Watch`.
214    /// This will cause calls to [`Rcv::get`] to be pending.
215    pub fn clear(&self) {
216        self.watch.clear()
217    }
218
219    /// Tries to retrieve the value of the `Watch`.
220    pub fn try_get(&self) -> Option<T> {
221        self.watch.try_get(None)
222    }
223
224    /// Returns true if the `Watch` contains a value.
225    pub fn contains_value(&self) -> bool {
226        self.watch.contains_value()
227    }
228}
229
230/// A sender of a `Watch` channel.
231///
232/// For a simpler type definition, consider [`DynSender`] at the expense of
233/// some runtime performance due to dynamic dispatch.
234#[derive(Debug)]
235pub struct Sender<'a, M: RawMutex, T: Clone>(Snd<'a, T, Watch<M, T>>);
236
237impl<'a, M: RawMutex, T: Clone> Clone for Sender<'a, M, T> {
238    fn clone(&self) -> Self {
239        Self(self.0.clone())
240    }
241}
242
243impl<'a, M: RawMutex, T: Clone> Deref for Sender<'a, M, T> {
244    type Target = Snd<'a, T, Watch<M, T>>;
245
246    fn deref(&self) -> &Self::Target {
247        &self.0
248    }
249}
250
251impl<'a, M: RawMutex, T: Clone> DerefMut for Sender<'a, M, T> {
252    fn deref_mut(&mut self) -> &mut Self::Target {
253        &mut self.0
254    }
255}
256
257/// A receiver can get a change in the `Watch` value.
258pub struct Rcv<'a, T: Clone, W: WatchBehavior<T> + ?Sized> {
259    watch: &'a W,
260    at_id: u8,
261    _phantom: PhantomData<T>,
262}
263
264impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Rcv<'a, T, W> {
265    /// Creates a new `Receiver` with a reference to the `Watch`.
266    fn new(watch: &'a W) -> Self {
267        Self {
268            watch,
269            at_id: 0,
270            _phantom: PhantomData,
271        }
272    }
273
274    /// Tries to get the current value of the `Watch` without waiting, marking it as seen.
275    pub fn try_get(&mut self) -> Option<T> {
276        self.watch.try_get(Some(&mut self.at_id))
277    }
278
279    /// Tries to get the new value of the watch without waiting, marking it as seen.
280    pub fn try_changed(&mut self) -> Option<T> {
281        self.watch.try_changed(&mut self.at_id)
282    }
283
284    /// Checks if the `Watch` contains a value. If this returns true,
285    /// then awaiting [`Rcv::get`] will return immediately.
286    pub fn contains_value(&self) -> bool {
287        self.watch.contains_value()
288    }
289}
290
291/// A receiver of a `Watch` channel.
292pub struct Receiver<'a, M: RawMutex, T: Clone>(Rcv<'a, T, Watch<M, T>>);
293
294impl<'a, M: RawMutex, T: Clone> Deref for Receiver<'a, M, T> {
295    type Target = Rcv<'a, T, Watch<M, T>>;
296
297    fn deref(&self) -> &Self::Target {
298        &self.0
299    }
300}
301
302impl<'a, M: RawMutex, T: Clone> DerefMut for Receiver<'a, M, T> {
303    fn deref_mut(&mut self) -> &mut Self::Target {
304        &mut self.0
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::Watch;
    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;

    #[test]
    fn multiple_sends() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain receiver and sender
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();

        // Not initialized
        assert_eq!(rcv.try_changed(), None);

        // Receive another value
        snd.send(20);
        assert_eq!(rcv.try_changed(), Some(20));

        // No update
        assert_eq!(rcv.try_changed(), None);
    }

    #[test]
    fn all_try_get() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain receiver and sender
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();

        // Not initialized
        assert_eq!(WATCH.try_get(), None);
        assert_eq!(rcv.try_get(), None);
        assert_eq!(snd.try_get(), None);

        // Receive the new value
        snd.send(10);
        assert_eq!(WATCH.try_get(), Some(10));
        assert_eq!(rcv.try_get(), Some(10));
        assert_eq!(snd.try_get(), Some(10));
    }

    #[test]
    fn once_lock_like() {
        static CONFIG0: u8 = 10;
        static CONFIG1: u8 = 20;

        static WATCH: Watch<CriticalSectionRawMutex, &'static u8> = Watch::new();

        // Obtain receiver and sender
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();

        // Not initialized
        assert_eq!(rcv.try_changed(), None);

        // Receive the new value
        snd.send(&CONFIG0);
        let rcv0 = rcv.try_changed().unwrap();
        assert_eq!(rcv0, &10);

        // Receive another value
        snd.send(&CONFIG1);
        let rcv1 = rcv.try_changed();
        assert_eq!(rcv1, Some(&20));

        // No update
        assert_eq!(rcv.try_changed(), None);

        // Ensure similarity with original static
        assert_eq!(rcv0, &CONFIG0);
        assert_eq!(rcv1, Some(&CONFIG1));
    }

    #[test]
    fn sender_modify() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain receiver and sender
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();

        // Receive the new value
        snd.send(10);
        assert_eq!(rcv.try_changed(), Some(10));
    }

    #[test]
    fn receive_after_create() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain sender and send value
        let snd = WATCH.sender();
        snd.send(10);

        // Obtain receiver and receive value
        let mut rcv = WATCH.receiver();
        assert_eq!(rcv.try_changed(), Some(10));
    }

    #[test]
    fn multiple_receivers() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain receivers and sender
        let mut rcv0 = WATCH.receiver();
        let mut rcv1 = WATCH.receiver();
        let snd = WATCH.sender();

        // No update for both
        assert_eq!(rcv0.try_changed(), None);
        assert_eq!(rcv1.try_changed(), None);

        // Send a new value
        snd.send(0);

        // Both receivers receive the new value
        assert_eq!(rcv0.try_changed(), Some(0));
        assert_eq!(rcv1.try_changed(), Some(0));

        // Each receiver tracks independently what it has seen
        assert_eq!(rcv0.try_changed(), None);
        assert_eq!(rcv1.try_changed(), None);
    }

    #[test]
    fn clone_senders() {
        // Obtain different ways to send
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let snd0 = WATCH.sender();
        let snd1 = snd0.clone();

        // Obtain Receiver
        let mut rcv = WATCH.receiver();

        // Send a value from first sender
        snd0.send(10);
        assert_eq!(rcv.try_changed(), Some(10));

        // Send a value from second sender
        snd1.send(20);
        assert_eq!(rcv.try_changed(), Some(20));
    }

    #[test]
    fn contains_value() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();

        // Obtain receiver and sender
        let rcv = WATCH.receiver();
        let snd = WATCH.sender();

        // check if the watch contains a value
        assert!(!rcv.contains_value());
        assert!(!snd.contains_value());

        // Send a value
        snd.send(10);

        // check if the watch contains a value
        assert!(rcv.contains_value());
        assert!(snd.contains_value());
    }

    #[test]
    fn clear() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let snd = WATCH.sender();

        snd.send(10);
        assert!(snd.contains_value());

        // Clearing empties the watch
        snd.clear();
        assert!(!snd.contains_value());
        assert_eq!(snd.try_get(), None);

        // A new value is visible again
        snd.send(20);
        assert_eq!(snd.try_get(), Some(20));
    }
}