per_thread_mutex/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5//! # per-thread-mutex
6//!
//! Synchronization lock intended for thread-unsafe C libraries.
8//!
9//! ## Rationale
10//!
//! When working with certain C libraries, concurrent accesses are not safe. This can be hard to
//! model at the Rust level, largely because language-level features can't enforce everything
//! that's necessary to maintain safety in all cases.
14//!
//! [`Send`][Send]/[`Sync`][Sync] can ensure that data structures are not sent to or used from
//! other threads, which provides part of the puzzle. However, in certain cases thread-unsafe
//! libraries can be used in a multithreaded context, provided two conditions are upheld:
//!
//! 1. Data structures are thread-localized, meaning any resource that is created in a thread is
//!    never sent to or used by another thread. This can be handled by [`Send`]/[`Sync`].
//! 2. There can be no concurrent calls into the library. This is not addressed by Rust
//!    language-level features.
23//!
24//! This crate aims to address requirement 2.
25//!
26//! ## How is it used?
27//!
28//! The intended use of this mutex is with lazy_static as a global variable in Rust bindings for
29//! thread-unsafe C code. The mutex should be locked before each call into the library. This
30//! ensures that there are never any concurrent accesses from separate threads which could lead to
31//! unsafe behavior.
32//!
33//! ## How does it work?
34//!
//! Besides the futex word that records whether the lock is held, the lock keeps track of two
//! pieces of data: the thread ID of the thread that currently holds the lock and the number of
//! acquisitions currently active on the lock. Multiple simultaneous acquisitions from the same
//! thread are allowed, and the lock becomes available to other threads once all acquisitions of
//! the lock are released.
39//!
40//! ## Why is the same thread permitted to acquire the mutex multiple times?
41//!
//! This largely stems from C's heavy use of callbacks. If a C API accepts a callback, it is
//! typical in Rust bindings to write the callback in Rust and to write a C shim to convert from C
//! to Rust data types. Consider an API call that, in its implementation, invokes a callback which
//! itself calls another Rust-wrapped API function. This is a safe usage of the library, but it
//! would result in a double acquisition of a traditional mutex guarding calls into the library.
//! This lock allows both of those acquisitions to succeed without blocking, preventing the
//! deadlock that a traditional mutex would cause while still guarding against unsafe concurrent
//! accesses to the library.
50
51use std::{
52    io,
53    sync::atomic::{AtomicU32, Ordering},
54};
55
56use libc::gettid;
57use log::trace;
58
/// A re-entrant lock, built on the Linux futex syscall, that serializes calls into a
/// thread-unsafe library.
///
/// The thread that acquires the lock while it is free becomes its owner and may acquire it
/// again without blocking; every other thread blocks in [`PerThreadMutex::acquire`] until all of
/// the owner's guards have been dropped.
#[derive(Debug)]
pub struct PerThreadMutex {
    /// Lock state, also used as the futex word: `0` = unlocked, `1` = locked.
    futex_word: AtomicU32,
    /// Kernel thread ID of the current owner, or `0` while the lock is free.
    thread_id: AtomicU32,
    /// Number of guards currently held by the owning thread.
    acquisitions: AtomicU32,
}
64
65impl Default for PerThreadMutex {
66    /// Create a new mutex.
67    fn default() -> Self {
68        PerThreadMutex {
69            futex_word: AtomicU32::new(0),
70            thread_id: AtomicU32::new(0),
71            acquisitions: AtomicU32::new(0),
72        }
73    }
74}
75
76impl PerThreadMutex {
77    /// Acquire a per-thread lock.
78    ///
79    /// The lock keeps track of the thread ID from which it is called. If a second acquire is called
80    /// from the same mutex, `acquire()` will grant a lock to that caller as well. Number of
81    /// acquisitions is tracked internally and the lock will be released when all acquisitions are
82    /// dropped.
83    pub fn acquire(&self) -> PerThreadMutexGuard<'_> {
84        loop {
85            if self
86                .futex_word
87                .compare_exchange_weak(0, 1, Ordering::AcqRel, Ordering::Acquire)
88                == Ok(0)
89            {
90                let thread_id = unsafe { libc::gettid() } as u32;
91                assert_eq!(self.acquisitions.fetch_add(1, Ordering::AcqRel), 0);
92                assert_eq!(
93                    self.thread_id.compare_exchange(
94                        0,
95                        thread_id,
96                        Ordering::AcqRel,
97                        Ordering::Acquire
98                    ),
99                    Ok(0)
100                );
101                trace!("[{}] Acquired initial lock", thread_id);
102                return PerThreadMutexGuard(self, thread_id);
103            } else {
104                let thread_id = unsafe { gettid() } as u32;
105                if self.thread_id.load(Ordering::Acquire) == thread_id {
106                    let count = self.acquisitions.fetch_add(1, Ordering::AcqRel);
107                    if count == u32::MAX {
108                        panic!("Acquisition counter overflowed");
109                    }
110                    trace!("[{}] Acquired lock number {}", thread_id, count + 1);
111                    return PerThreadMutexGuard(self, thread_id);
112                } else {
113                    trace!("[{}] Thread is waiting", unsafe { libc::gettid() });
114                    match unsafe {
115                        libc::syscall(
116                            libc::SYS_futex,
117                            self.futex_word.as_ptr(),
118                            libc::FUTEX_WAIT,
119                            1,
120                            0,
121                            0,
122                            0,
123                        )
124                    } {
125                        0 => (),
126                        _ => match io::Error::last_os_error().raw_os_error() {
127                            Some(libc::EINTR | libc::EAGAIN) => (),
128                            Some(libc::EACCES) => {
129                                unreachable!("Local variable is always readable")
130                            }
131                            Some(i) => unreachable!(
132                                "Only EAGAIN, EACCES, and EINTR are returned by FUTEX_WAIT; got {}",
133                                i
134                            ),
135                            None => unreachable!(),
136                        },
137                    }
138                }
139            }
140        }
141    }
142}
143
144/// Guard indicating that the per-thread lock is still acquired. Dropping this lock causes all
145/// waiters to be woken up. This mutex is not fair so the lock will be acquired by
146/// the first thread that requests the acquisition.
147pub struct PerThreadMutexGuard<'a>(&'a PerThreadMutex, u32);
148
149impl Drop for PerThreadMutexGuard<'_> {
150    fn drop(&mut self) {
151        let acquisitions = self.0.acquisitions.fetch_sub(1, Ordering::AcqRel);
152        assert!(acquisitions > 0);
153        if acquisitions == 1 {
154            assert_eq!(
155                self.0
156                    .thread_id
157                    .compare_exchange(self.1, 0, Ordering::AcqRel, Ordering::Acquire),
158                Ok(self.1)
159            );
160            assert_eq!(
161                self.0
162                    .futex_word
163                    .compare_exchange(1, 0, Ordering::AcqRel, Ordering::Acquire),
164                Ok(1)
165            );
166            trace!("[{}] Unlocking mutex", self.1);
167            let i = unsafe {
168                libc::syscall(
169                    libc::SYS_futex,
170                    self.0.futex_word.as_ptr(),
171                    libc::FUTEX_WAKE as i64,
172                    libc::INT_MAX as i64,
173                    0,
174                    0,
175                    0,
176                )
177            };
178            trace!("[{}] Number of waiters woken: {}", self.1, i);
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        sync::{
            atomic::{AtomicU32, Ordering},
            Arc,
        },
        thread::spawn,
    };

    /// Several threads each take the lock re-entrantly a different number of times. The test
    /// checks that re-entrant acquisitions never block and that no two threads are ever inside
    /// the critical section at the same time.
    #[test]
    fn test_lock() {
        // `try_init` instead of `init`: `init` panics if a logger is already installed, which
        // happens as soon as a second test in this binary sets up logging.
        let _ = env_logger::try_init();

        let mutex = Arc::new(PerThreadMutex::default());
        // Number of threads currently holding at least one guard; must never exceed 1.
        let holders = Arc::new(AtomicU32::new(0));

        let handles: Vec<_> = [3usize, 4, 2, 5, 1]
            .iter()
            .map(|&depth| {
                let mutex = Arc::clone(&mutex);
                let holders = Arc::clone(&holders);
                spawn(move || {
                    let first = mutex.acquire();
                    assert_eq!(
                        holders.fetch_add(1, Ordering::SeqCst),
                        0,
                        "another thread already holds the lock"
                    );

                    // Re-entrant acquisitions from the owning thread must not block.
                    let nested: Vec<_> = (1..depth).map(|_| mutex.acquire()).collect();
                    drop(nested);

                    assert_eq!(holders.fetch_sub(1, Ordering::SeqCst), 1);
                    drop(first);
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }
}