embedded_async_helpers/fair_share.rs

//! A mutex which internally uses an intrusive linked list whose nodes are stored inside the waiting `Future`s.
//! Read/write access is guaranteed to be fair, with FIFO semantics.
//!
//! **WARNING:** Don't `mem::forget` the access future, or there will be dragons.
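//!
//! # Example
//!
//! A minimal usage sketch, assuming some async executor is driving the task:
//!
//! ```ignore
//! static SHARED: FairShare<u32> = FairShare::new(0);
//!
//! async fn task() {
//!     // Wait in line for exclusive access
//!     let mut access = SHARED.access().await;
//!     *access += 1;
//!     // Dropping the guard (or calling `release`) wakes the next waiter
//!     access.release();
//! }
//! ```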

use core::{
    cell::UnsafeCell,
    future::Future,
    ops::{Deref, DerefMut},
    pin::Pin,
    ptr::NonNull,
    task::{Context, Poll, Waker},
};

// Wrapper type for a place in the queue.
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
struct FairSharePlace(usize);

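// Queue bookkeeping: `idx_in` is the next place handed out to a new waiter and
// `idx_out` is the place currently allowed to take the value, giving strict
// FIFO (ticket-style) ordering.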
struct FairShareManagement {
    idx_in: FairSharePlace,
    idx_out: FairSharePlace,
    queue: IntrusiveSinglyLinkedList,
}

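// A queue node that lives inside the waiting future itself; the shared queue
// only stores raw pointers to these nodes, which is why the access future must
// not be `mem::forget`ten while it is enqueued.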
struct IntrusiveWakerNode {
    waker: Waker,
    place: FairSharePlace,
    next: Option<NonNull<IntrusiveWakerNode>>,
}

impl IntrusiveWakerNode {
    const fn new(waker: Waker) -> Self {
        IntrusiveWakerNode {
            waker,
            place: FairSharePlace(0),
            next: None,
        }
    }
}

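// Intrusive FIFO of waiting futures: `head` is the oldest waiter, `tail` the
// newest. The list never owns its nodes, it only links to them.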
struct IntrusiveSinglyLinkedList {
    head: Option<NonNull<IntrusiveWakerNode>>,
    tail: Option<NonNull<IntrusiveWakerNode>>,
}

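// Safety: the list only holds raw pointers to nodes, and it is only ever
// accessed from within critical sections.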
unsafe impl Send for IntrusiveSinglyLinkedList {}

impl IntrusiveSinglyLinkedList {
    const fn new() -> Self {
        IntrusiveSinglyLinkedList {
            head: None,
            tail: None,
        }
    }

    fn is_empty(&self) -> bool {
        self.head.is_none()
    }

    fn peek(&self) -> Option<&IntrusiveWakerNode> {
        self.head.map(|v| unsafe { v.as_ref() })
    }

    fn push_back(&mut self, val: NonNull<IntrusiveWakerNode>) {
        if self.head.is_some() {
            if let Some(mut tail) = self.tail {
                unsafe { tail.as_mut() }.next = Some(val);
                self.tail = Some(val);
            } else {
                // If head is some, tail is some
                unreachable!()
            }
        } else {
            self.head = Some(val);
            self.tail = self.head;
        }
    }

    fn pop_head(&mut self) -> Option<NonNull<IntrusiveWakerNode>> {
        if let Some(head) = self.head {
            if self.head != self.tail {
                let ret = self.head;
                self.head = unsafe { head.as_ref() }.next;
                return ret;
            } else {
                // This was the last element, the list is now empty
                let ret = self.head;
                self.head = None;
                self.tail = None;
                return ret;
            }
        }

        None
    }

    // Remove the node with place `idx`. Returns the place if it was first in
    // the queue, so the caller can update its counters.
    fn pop_idx(&mut self, idx: FairSharePlace) -> Option<FairSharePlace> {
        // 1. Check if the head should be replaced
        if let Some(head) = self.head {
            let head = unsafe { head.as_ref() };

            if head.place == idx {
                if self.head != self.tail {
                    // There is more than 1 element in the list
                    self.head = head.next;
                } else {
                    // There is only 1 element in the list
                    self.head = None;
                    self.tail = None;
                }

                return Some(head.place);
            }
        } else {
            // The list is empty
            return None;
        }

        // 2. It was not the first element, search for it
        let mut current = self.head;

        while let Some(mut c) = current {
            let next = unsafe { c.as_ref() }.next;

            // Check if the next node is the one that should be removed
            if let Some(next_ptr) = next {
                if unsafe { next_ptr.as_ref() }.place == idx {
                    // Unlink it; if it was the tail, the current node becomes the new tail
                    unsafe { c.as_mut() }.next = unsafe { next_ptr.as_ref() }.next;

                    if self.tail == Some(next_ptr) {
                        self.tail = Some(c);
                    }

                    break;
                }
            }

            current = next;
        }

        None
    }


    // Debug helper that walks the list (produces no output as written).
    #[allow(unused)]
    fn print(&self) {
        let mut head = self.head;

        while let Some(h) = head {
            let h2 = unsafe { h.as_ref() };

            head = h2.next;
        }
    }
}

impl FairShareManagement {
    fn enqueue(&mut self, node: &mut IntrusiveWakerNode) -> FairSharePlace {
        let current = self.idx_in;
        self.idx_in = FairSharePlace(current.0.wrapping_add(1));

        node.place = current;

        self.queue.push_back(node.into());

        current
    }

    fn dequeue(&mut self) {
        self.queue.pop_head();
    }

    fn wake_next_in_queue(&mut self) -> Option<Waker> {
        if let Some(node) = self.queue.peek() {
            self.idx_out = node.place;

            Some(node.waker.clone())
        } else {
            self.idx_out = self.idx_in;

            None
        }
    }

    fn try_direct_access(&mut self) -> bool {
        if self.queue.is_empty() && self.idx_in == self.idx_out {
            // Take the next place so that concurrent requests queue up behind us
            let current = self.idx_in;
            self.idx_in = FairSharePlace(current.0.wrapping_add(1));

            true
        } else {
            false
        }
    }
}

/// Async fair sharing of an underlying value.
pub struct FairShare<T> {
    /// Holds the underlying type; this can only safely be accessed from `FairShareExclusiveAccess`.
    storage: UnsafeCell<T>,
    /// Holds the queue bookkeeping; this is guarded with critical section tokens.
    management: UnsafeCell<FairShareManagement>,
}

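// Safety: `management` is only accessed from within critical sections, and
// `storage` is only reachable through the exclusive access guard.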
unsafe impl<T> Sync for FairShare<T> {}

impl<T> FairShare<T> {
    /// Create a new fair share; generally place this in static storage and pass around references.
    pub const fn new(val: T) -> Self {
        FairShare {
            storage: UnsafeCell::new(val),
            management: UnsafeCell::new(FairShareManagement {
                idx_in: FairSharePlace(0),
                idx_out: FairSharePlace(0),
                queue: IntrusiveSinglyLinkedList::new(),
            }),
        }
    }

    fn get_management<'a>(
        &self,
        _token: &'a mut critical_section::CriticalSection,
    ) -> &'a mut FairShareManagement {
        // Safety: we are inside a critical section (witnessed by `_token`), so nothing
        // else can access the management data concurrently
        unsafe { &mut *(self.management.get()) }
    }

    /// Request access; await the returned future to be woken when it is available.
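    ///
    /// A minimal sketch, assuming a `static SHARED: FairShare<u32>` and an async context:
    ///
    /// ```ignore
    /// let mut access = SHARED.access().await;
    /// *access += 1;
    /// ```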
    pub fn access<'a>(&'a self) -> FairShareAccessFuture<'a, T> {
        FairShareAccessFuture {
            fs: self,
            node: UnsafeCell::new(None),
            place: None,
        }
    }
}

/// Future returned by [`FairShare::access`]; resolves to exclusive access when it is this
/// future's turn in the queue.
pub struct FairShareAccessFuture<'a, T> {
    fs: &'a FairShare<T>,
    node: UnsafeCell<Option<IntrusiveWakerNode>>,
    place: Option<FairSharePlace>,
}

impl<'a, T> Drop for FairShareAccessFuture<'a, T> {
    fn drop(&mut self) {
        if let Some(place) = self.place {
            let waker = critical_section::with(|mut token| {
                let fs = self.fs.get_management(&mut token);

                // Remove this future from the queue
                if fs.queue.pop_idx(place).is_some() && fs.idx_out == place {
                    // We were first in line and had already been granted the turn,
                    // so pass it on to the next waiter instead of stalling the queue
                    fs.wake_next_in_queue()
                } else {
                    None
                }
            });

            // Run the waker outside of the critical section to minimize its size
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }
}

impl<'a, T> Future for FairShareAccessFuture<'a, T> {
    type Output = FairShareExclusiveAccess<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        critical_section::with(|mut token| {
            let fs = self.fs.get_management(&mut token);

            if let Some(place) = self.place {
                if fs.idx_out == place {
                    // Our turn
                    fs.dequeue();
                    self.place = None;
                    Poll::Ready(FairShareExclusiveAccess { fs: self.fs })
                } else {
                    // Continue waiting
                    Poll::Pending
                }
            } else {
                // If nothing is queued and nothing holds access, take it directly without waiting
                if fs.try_direct_access() {
                    Poll::Ready(FairShareExclusiveAccess { fs: self.fs })
                } else {
                    let node = self
                        .node
                        .get_mut()
                        .insert(IntrusiveWakerNode::new(cx.waker().clone()));

                    // We are not in the queue yet, enqueue our waker
                    self.place = Some(fs.enqueue(node));

                    Poll::Pending
                }
            }
        })
    }
}

/// Exclusive access to the underlying storage until released or dropped.
pub struct FairShareExclusiveAccess<'a, T> {
    fs: &'a FairShare<T>,
}

impl<'a, T> Deref for FairShareExclusiveAccess<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // Safety: We can generate multiple immutable references to the underlying type.
        // If any mutable reference is generated we are protected via `&self`.
        unsafe { &*(self.fs.storage.get()) }
    }
}

impl<'a, T> DerefMut for FairShareExclusiveAccess<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Safety: We can generate a single mutable reference to the underlying type.
        // If any immutable reference is generated we are protected via `&mut self`.
        unsafe { &mut *(self.fs.storage.get()) }
    }
}

impl<T> FairShareExclusiveAccess<'_, T> {
    /// Release exclusive access; this is equivalent to dropping the guard.
    pub fn release(self) {
        // Consuming `self` runs the `Drop` impl, which wakes the next waiter
    }
}

impl<T> Drop for FairShareExclusiveAccess<'_, T> {
    fn drop(&mut self) {
        let waker = critical_section::with(|mut token| {
            self.fs.get_management(&mut token).wake_next_in_queue()
        });

        // Run the waker outside of the critical section to minimize its size
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}