mintex/
lib.rs

//! mintex is a *min*imal Mutex.
//!
//! Most of the implementation is lifted from [`std::sync::Mutex`].
//! This mutex exists because I'd like a mutex which is quite lightweight
//! and does not perform allocations.
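//!
//! A minimal usage sketch (this example assumes the crate is named `mintex`):
//!
//! ```
//! use mintex::Mutex;
//!
//! let m = Mutex::new(0u32);
//! {
//!     let mut guard = m.lock();
//!     *guard += 1;
//! } // the guard is dropped here, releasing the lock
//! assert_eq!(*m.lock(), 1);
//! ```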

use std::cell::UnsafeCell;
use std::fmt;
use std::hint;
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::thread;

// Empirically a good number of spin iterations on an Apple M1.
const LOOP_LIMIT: usize = 250;

/// A minimal mutual-exclusion lock protecting a value of type `T`.
pub struct Mutex<T: ?Sized> {
    /// `true` while the lock is held.
    lock: AtomicBool,
    /// The data protected by the lock.
    data: UnsafeCell<T>,
}

// SAFETY: the mutex hands out references to the protected data only while the
// lock is held, so the same `T: Send` bounds as `std::sync::Mutex` apply.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}

impl<T> From<T> for Mutex<T> {
    /// Creates a new mutex in an unlocked state ready for use.
    /// This is equivalent to [`Mutex::new`].
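    ///
    /// A small sketch:
    ///
    /// ```
    /// use mintex::Mutex;
    ///
    /// let m = Mutex::from(5);
    /// assert_eq!(*m.lock(), 5);
    /// ```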
    fn from(t: T) -> Self {
        Mutex::new(t)
    }
}

impl<T: ?Sized + Default> Default for Mutex<T> {
    /// Creates a `Mutex<T>` with the `Default` value for `T`.
    fn default() -> Mutex<T> {
        Mutex::new(Default::default())
    }
}

impl<T> Mutex<T> {
    #[inline]
    /// Create a new `Mutex` which wraps the provided data.
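    ///
    /// Because `new` is a `const fn` and performs no allocation, a `Mutex` can
    /// be used in a `static`; a small sketch:
    ///
    /// ```
    /// use mintex::Mutex;
    ///
    /// static COUNTER: Mutex<u64> = Mutex::new(0);
    ///
    /// *COUNTER.lock() += 1;
    /// assert_eq!(*COUNTER.lock(), 1);
    /// ```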
    pub const fn new(data: T) -> Self {
        Self {
            lock: AtomicBool::new(false),
            data: UnsafeCell::new(data),
        }
    }
}

impl<T: ?Sized> Mutex<T> {
    /// Acquire the lock, returning a RAII `MutexGuard` over the protected data.
    /// Under contention this spins, periodically yielding the thread, until the
    /// lock becomes available.
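    ///
    /// A minimal sketch of locking and mutating the protected data:
    ///
    /// ```
    /// use mintex::Mutex;
    ///
    /// let m = Mutex::new(vec![1, 2, 3]);
    /// m.lock().push(4);
    /// assert_eq!(m.lock().len(), 4);
    /// ```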
    pub fn lock(&self) -> MutexGuard<'_, T> {
        let mut loop_count = 0;
        loop {
            // Attempt to flip the lock from unlocked (false) to locked (true).
            match self
                .lock
                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            {
                Ok(v) => {
                    debug_assert!(!v);
                    unsafe {
                        return MutexGuard::new(self);
                    }
                }
                Err(_e) => {
                    // Contended: spin with a CPU hint, yielding to the
                    // scheduler every LOOP_LIMIT iterations.
                    if loop_count > LOOP_LIMIT {
                        loop_count = 0;
                        thread::yield_now();
                    } else {
                        loop_count += 1;
                        hint::spin_loop();
                    }
                }
            }
        }
    }

    /// Unlock a mutex by consuming (dropping) its `MutexGuard`.
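    ///
    /// A small sketch of releasing the lock explicitly instead of waiting for
    /// the guard to fall out of scope:
    ///
    /// ```
    /// use mintex::Mutex;
    ///
    /// let m = Mutex::new(0u32);
    /// let mut guard = m.lock();
    /// *guard += 1;
    /// Mutex::unlock(guard); // equivalent to drop(guard)
    /// assert_eq!(*m.lock(), 1);
    /// ```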
    pub fn unlock(guard: MutexGuard<'_, T>) {
        drop(guard);
    }
}

/// RAII guard over the locked data. The lock is released when the guard is dropped.
pub struct MutexGuard<'a, T: ?Sized + 'a> {
    mutex: &'a Mutex<T>,
}

// It would be nice to mark the MutexGuard as `!Send`, but negative impls are not stable yet.
// impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}

impl<'mutex, T: ?Sized> MutexGuard<'mutex, T> {
    unsafe fn new(mutex: &'mutex Mutex<T>) -> MutexGuard<'mutex, T> {
        MutexGuard { mutex }
    }
}

impl<T: ?Sized> Deref for MutexGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.mutex.data.get() }
    }
}

impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.mutex.data.get() }
    }
}

impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        // Dropping the guard releases the lock; the `Release` store pairs with
        // the `Acquire` in `Mutex::lock`.
        self.mutex.lock.store(false, Ordering::Release);
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn exercise_mutex_lock() {
        const N: usize = 100;

        // Spawn N threads to increment a shared variable (non-atomically), and
        // let the main thread know once all increments are done.

        let (tx, rx) = channel();

        let data: usize = 0;

        let my_lock = Arc::new(Mutex::new(data));

        for _ in 0..N {
            let tx = tx.clone();
            let my_lock = my_lock.clone();
            thread::spawn(move || {
                let mut data = my_lock.lock();
                *data += 1;
                println!("after data: {}", data);
                if *data == N {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();
    }
}