per_thread_mutex/lib.rs
1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5//! # per-thread-mutex
6//!
7//! Synchronization lock intended for thread unsafe C libraries.
8//!
9//! ## Rationale
10//!
11//! When working with certain C libraries, concurrent accesses are not safe. It can be problematic
12//! to model this at the Rust level largely because language level support can't enforce everything
13//! that's necessary to maintain safety in all cases.
14//!
15//! [`Send`][Send]/[`Sync`][Sync] can ensure that data structures are not used and sent across
16//! threads which provides part of the puzzle. However for certain cases thread-unsafe libraries
17//! can be used in a multithreaded context provided two conditions are upheld.
18//!
19//! 1. Data structures are thread-localized, meaning any resource that is created in a thread is
//! never sent or used by another thread. This can be handled by [`Send`]/[`Sync`].
21//! 2. There can be no concurrent calls into the library. This is not addressed by Rust language
22//! level features.
23//!
24//! This crate aims to address requirement 2.
25//!
26//! ## How is it used?
27//!
28//! The intended use of this mutex is with lazy_static as a global variable in Rust bindings for
29//! thread-unsafe C code. The mutex should be locked before each call into the library. This
30//! ensures that there are never any concurrent accesses from separate threads which could lead to
31//! unsafe behavior.
32//!
33//! ## How does it work?
34//!
35//! The lock keeps track of two pieces of data: the thread ID of the thread that currently has the
36//! lock acquisition and the number of acquisitions currently active on the lock. Acquisitions from
//! the same thread ID are allowed at the same time, and the lock becomes available once all
//! acquisitions of the lock are released.
39//!
40//! ## Why is the same thread permitted to acquire the mutex multiple times?
41//!
42//! This largely stems from C's heavy use of callbacks. If a callback is built into a C API, it is
43//! typical in Rust bindings to write the callback in Rust and to write a C shim to convert from C
44//! to Rust data types. Consider the case of an API call that, in its implementation, calls a
45//! callback where the callback also calls a Rust-wrapped API call. This is a safe usage of the
46//! library, but would result in a double acquisition of a traditional mutex guarding calls into
47//! the library. This lock allows both of those acquisitions to succeed without blocking,
//! preventing the deadlock that would be caused by a traditional mutex while still guarding
//! against unsafe accesses of the library.
50
51use std::{
52 io,
53 sync::atomic::{AtomicU32, Ordering},
54};
55
56use libc::gettid;
57use log::trace;
58
/// Futex-backed lock that may be held several times over by a single thread.
///
/// All state lives in atomics so the mutex can be shared between threads by reference.
#[derive(Debug)]
pub struct PerThreadMutex {
    /// Word handed to the futex system call: `0` while the mutex is free, `1` while it is held.
    futex_word: AtomicU32,
    /// Thread ID of the current holder, or `0` while nobody holds the mutex.
    thread_id: AtomicU32,
    /// Number of guards currently outstanding for the holding thread.
    acquisitions: AtomicU32,
}
64
65impl Default for PerThreadMutex {
66 /// Create a new mutex.
67 fn default() -> Self {
68 PerThreadMutex {
69 futex_word: AtomicU32::new(0),
70 thread_id: AtomicU32::new(0),
71 acquisitions: AtomicU32::new(0),
72 }
73 }
74}
75
76impl PerThreadMutex {
77 /// Acquire a per-thread lock.
78 ///
79 /// The lock keeps track of the thread ID from which it is called. If a second acquire is called
80 /// from the same mutex, `acquire()` will grant a lock to that caller as well. Number of
81 /// acquisitions is tracked internally and the lock will be released when all acquisitions are
82 /// dropped.
83 pub fn acquire(&self) -> PerThreadMutexGuard<'_> {
84 loop {
85 if self
86 .futex_word
87 .compare_exchange_weak(0, 1, Ordering::AcqRel, Ordering::Acquire)
88 == Ok(0)
89 {
90 let thread_id = unsafe { libc::gettid() } as u32;
91 assert_eq!(self.acquisitions.fetch_add(1, Ordering::AcqRel), 0);
92 assert_eq!(
93 self.thread_id.compare_exchange(
94 0,
95 thread_id,
96 Ordering::AcqRel,
97 Ordering::Acquire
98 ),
99 Ok(0)
100 );
101 trace!("[{}] Acquired initial lock", thread_id);
102 return PerThreadMutexGuard(self, thread_id);
103 } else {
104 let thread_id = unsafe { gettid() } as u32;
105 if self.thread_id.load(Ordering::Acquire) == thread_id {
106 let count = self.acquisitions.fetch_add(1, Ordering::AcqRel);
107 if count == u32::MAX {
108 panic!("Acquisition counter overflowed");
109 }
110 trace!("[{}] Acquired lock number {}", thread_id, count + 1);
111 return PerThreadMutexGuard(self, thread_id);
112 } else {
113 trace!("[{}] Thread is waiting", unsafe { libc::gettid() });
114 match unsafe {
115 libc::syscall(
116 libc::SYS_futex,
117 self.futex_word.as_ptr(),
118 libc::FUTEX_WAIT,
119 1,
120 0,
121 0,
122 0,
123 )
124 } {
125 0 => (),
126 _ => match io::Error::last_os_error().raw_os_error() {
127 Some(libc::EINTR | libc::EAGAIN) => (),
128 Some(libc::EACCES) => {
129 unreachable!("Local variable is always readable")
130 }
131 Some(i) => unreachable!(
132 "Only EAGAIN, EACCES, and EINTR are returned by FUTEX_WAIT; got {}",
133 i
134 ),
135 None => unreachable!(),
136 },
137 }
138 }
139 }
140 }
141 }
142}
143
144/// Guard indicating that the per-thread lock is still acquired. Dropping this lock causes all
145/// waiters to be woken up. This mutex is not fair so the lock will be acquired by
146/// the first thread that requests the acquisition.
147pub struct PerThreadMutexGuard<'a>(&'a PerThreadMutex, u32);
148
149impl Drop for PerThreadMutexGuard<'_> {
150 fn drop(&mut self) {
151 let acquisitions = self.0.acquisitions.fetch_sub(1, Ordering::AcqRel);
152 assert!(acquisitions > 0);
153 if acquisitions == 1 {
154 assert_eq!(
155 self.0
156 .thread_id
157 .compare_exchange(self.1, 0, Ordering::AcqRel, Ordering::Acquire),
158 Ok(self.1)
159 );
160 assert_eq!(
161 self.0
162 .futex_word
163 .compare_exchange(1, 0, Ordering::AcqRel, Ordering::Acquire),
164 Ok(1)
165 );
166 trace!("[{}] Unlocking mutex", self.1);
167 let i = unsafe {
168 libc::syscall(
169 libc::SYS_futex,
170 self.0.futex_word.as_ptr(),
171 libc::FUTEX_WAKE as i64,
172 libc::INT_MAX as i64,
173 0,
174 0,
175 0,
176 )
177 };
178 trace!("[{}] Number of waiters woken: {}", self.1, i);
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    use std::{sync::Arc, thread::spawn};

    use env_logger::init;

    /// Five threads share one mutex; each takes the lock a different number of times, holds all
    /// of its guards until the thread body ends, and must finish without panicking.
    #[test]
    fn test_lock() {
        init();

        let mutex = Arc::new(PerThreadMutex::default());

        // Nesting depth used by each spawned thread, in spawn order.
        let depths = [3usize, 4, 2, 5, 1];

        // Start every thread before joining any of them.
        let handles: Vec<_> = depths
            .iter()
            .map(|&depth| {
                let mutex_clone = Arc::clone(&mutex);
                spawn(move || {
                    let lock: &PerThreadMutex = &mutex_clone;
                    let _guards: Vec<_> = (0..depth).map(|_| lock.acquire()).collect();
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }
}