embedded_threads/
lock.rs

1use core::cell::UnsafeCell;
2
3use super::arch::interrupt;
4use super::threadlist::ThreadList;
5use super::{ThreadState, Threads};
6
/// A blocking lock for threads managed by this crate's scheduler.
///
/// Contended `acquire` calls park the calling thread on a waiter list instead
/// of spinning. All state changes happen inside `interrupt::free`
/// (see the methods on `Lock`).
pub struct Lock {
    // Interior-mutable lock state; only accessed inside `interrupt::free`.
    state: UnsafeCell<LockState>,
}
10
// SAFETY: every method of `Lock` touches `state` only inside an
// `interrupt::free` closure. This relies on that critical section serializing
// all such accesses (e.g. single core with interrupts masked).
// NOTE(review): confirm `interrupt::free` gives mutual exclusion on every
// supported target; on a multi-core target this impl would need a real lock.
unsafe impl Sync for Lock {}
12
/// Internal state of a [`Lock`].
enum LockState {
    /// Nobody holds the lock.
    Unlocked,
    /// The lock is held. The list collects threads blocked in `acquire`;
    /// `release` pops from it to pick the next holder
    /// (presumably thread ids; pop order is defined by `ThreadList`).
    Locked(ThreadList),
}
17
18impl Lock {
19    pub const fn new() -> Self {
20        Self {
21            state: UnsafeCell::new(LockState::Unlocked),
22        }
23    }
24
25    pub const fn new_locked() -> Self {
26        Self {
27            state: UnsafeCell::new(LockState::Locked(ThreadList::new())),
28        }
29    }
30
31    pub fn is_locked(&self) -> bool {
32        interrupt::free(|_| {
33            let state = unsafe { &*self.state.get() };
34            match state {
35                LockState::Unlocked => false,
36                _ => true,
37            }
38        })
39    }
40
41    pub fn acquire(&self) {
42        interrupt::free(|cs| {
43            let state = unsafe { &mut *self.state.get() };
44            match state {
45                LockState::Unlocked => *state = LockState::Locked(ThreadList::new()),
46                LockState::Locked(waiters) => {
47                    unsafe { Threads::get_mut(cs) }.current_wait_on(waiters, ThreadState::LockWait);
48                }
49            }
50        })
51    }
52
53    pub fn try_acquire(&self) -> bool {
54        interrupt::free(|_| {
55            let state = unsafe { &mut *self.state.get() };
56            match state {
57                LockState::Unlocked => {
58                    *state = LockState::Locked(ThreadList::new());
59                    true
60                }
61                LockState::Locked(_) => false,
62            }
63        })
64    }
65
66    pub fn release(&self) {
67        interrupt::free(|cs| {
68            let state = unsafe { &mut *self.state.get() };
69            match state {
70                LockState::Unlocked => {} // TODO: panic?
71                LockState::Locked(waiters) => {
72                    if let Some(thread_id) = waiters.pop(cs) {
73                        //super::println!("unlocking {}", thread_id);
74                        unsafe { Threads::get_mut(cs) }.wake_pid(thread_id);
75                    } else {
76                        *state = LockState::Unlocked
77                    }
78                }
79            }
80        })
81    }
82}