use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
#[cfg(panic = "unwind")]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{LockResult, PoisonError};
#[cfg(panic = "unwind")]
use std::thread;
/// Poison state shared between a mutex and the guards handed out for it.
///
/// The `failed` field only exists when panics unwind; under any other panic
/// strategy the struct is empty and `get` always reports "not poisoned".
struct Flag {
    #[cfg(panic = "unwind")]
    failed: AtomicBool,
}

impl Flag {
    /// Builds a flag in the un-poisoned state.
    #[inline]
    const fn new() -> Self {
        Self {
            #[cfg(panic = "unwind")]
            failed: AtomicBool::new(false),
        }
    }

    /// Reports the poison state without producing a [`Guard`].
    ///
    /// Returns `Err(PoisonError)` wrapping `()` when the flag is set and
    /// `Ok(())` otherwise.
    #[inline]
    fn borrow(&self) -> LockResult<()> {
        if !self.get() {
            return Ok(());
        }
        Err(PoisonError::new(()))
    }

    /// Creates the [`Guard`] token for a freshly acquired lock.
    ///
    /// The token records whether the current thread was already panicking at
    /// this moment. It is returned in both cases: as `Ok` when the flag is
    /// clear and inside a `PoisonError` when the flag is set.
    #[inline]
    fn guard(&self) -> LockResult<Guard> {
        let token = Guard {
            #[cfg(panic = "unwind")]
            panicking: thread::panicking(),
        };
        if !self.get() {
            return Ok(token);
        }
        Err(PoisonError::new(token))
    }

    /// Sets the flag when a panic started after `guard` was created.
    ///
    /// Nothing is written when the thread was already panicking at guard
    /// creation, or when it is not panicking now.
    #[inline]
    #[cfg(panic = "unwind")]
    fn done(&self, guard: &Guard) {
        if guard.panicking {
            return;
        }
        if thread::panicking() {
            self.failed.store(true, Ordering::Relaxed);
        }
    }

    /// Without unwinding there is nothing to record, so this does nothing.
    #[inline]
    #[cfg(not(panic = "unwind"))]
    fn done(&self, _guard: &Guard) {}

    /// Reads the current poison state.
    #[inline]
    #[cfg(panic = "unwind")]
    fn get(&self) -> bool {
        self.failed.load(Ordering::Relaxed)
    }

    /// Without unwinding the flag can never be set.
    #[inline(always)]
    #[cfg(not(panic = "unwind"))]
    fn get(&self) -> bool {
        false
    }

    /// Resets the flag to the un-poisoned state.
    #[inline]
    fn clear(&self) {
        #[cfg(panic = "unwind")]
        self.failed.store(false, Ordering::Relaxed);
    }
}

/// Token created by [`Flag::guard`] and handed back to [`Flag::done`].
///
/// It remembers whether the thread was panicking when the token was created.
#[derive(Clone)]
pub(crate) struct Guard {
    #[cfg(panic = "unwind")]
    panicking: bool,
}
/// A mutual-exclusion wrapper that owns a value of type `T` and guards it with a
/// `crate::ffi::PyMutex`.
///
/// Access goes through [`PyMutex::lock`], which returns a [`PyMutexGuard`] inside a
/// `LockResult`. When panics unwind, dropping a guard during a panic marks the mutex
/// as poisoned and later `lock` / `into_inner` calls return `Err(PoisonError)`.
pub struct PyMutex<T: ?Sized> {
// Raw lock state; its address is what `PyMutex_Lock` / `PyMutex_Unlock` receive.
pub(crate) mutex: UnsafeCell<crate::ffi::PyMutex>,
// Remembers whether a guard was released while its thread was panicking.
poison: Flag,
// The protected value. Within this file it is reached through a guard's
// `Deref`/`DerefMut` or through `into_inner`.
pub(crate) data: UnsafeCell<T>,
}
/// Guard returned by [`PyMutex::lock`]; dereferences to the protected `T`.
///
/// Dropping the guard updates the poison flag and then calls `PyMutex_Unlock` on the
/// owning mutex.
pub struct PyMutexGuard<'a, T: ?Sized> {
// The mutex this guard will unlock on drop.
inner: &'a PyMutex<T>,
// Snapshot taken at lock time of whether the thread was already panicking.
poison: Guard,
// A raw-pointer marker opts the guard out of the automatic `Send`/`Sync` impls;
// `Sync` is re-added by an explicit `unsafe impl`, so the guard ends up `!Send`.
_phantom: PhantomData<*const ()>,
}
// SAFETY: a `&PyMutexGuard<T>` only exposes `&T` (through `Deref`), so sharing the
// guard between threads is sound exactly when sharing `&T` is, i.e. `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for PyMutexGuard<'_, T> {}
// SAFETY: moving the mutex to another thread moves the owned `T` with it (and
// `into_inner` can hand that `T` out there), which requires `T: Send`.
unsafe impl<T: ?Sized + Send> Send for PyMutex<T> {}
// SAFETY: a shared `&PyMutex<T>` lets any thread lock and obtain `&mut T` through the
// guard, so `T` must be `Send`. `T: Sync` is not required because the lock is what
// hands out access to `T`, one guard at a time.
unsafe impl<T: ?Sized + Send> Sync for PyMutex<T> {}
impl<T> PyMutex<T> {
    /// Creates an unlocked, un-poisoned mutex that owns `value`.
    pub const fn new(value: T) -> Self {
        Self {
            mutex: UnsafeCell::new(crate::ffi::PyMutex::new()),
            poison: Flag::new(),
            data: UnsafeCell::new(value),
        }
    }

    /// Acquires the mutex through `PyMutex_Lock` and returns a guard for the
    /// protected value.
    ///
    /// # Errors
    ///
    /// Returns `Err(PoisonError)` when the poison flag is set. The error still
    /// carries the guard, so the lock is held in both cases and the guard can
    /// be recovered with `PoisonError::into_inner`.
    pub fn lock(&self) -> LockResult<PyMutexGuard<'_, T>> {
        let raw = self.mutex.get();
        // SAFETY: `raw` points at the `crate::ffi::PyMutex` stored inside
        // `self`, which stays alive for the whole call.
        unsafe { crate::ffi::PyMutex_Lock(raw) };
        PyMutexGuard::new(self)
    }

    /// Reports whether `PyMutex_IsLocked` returns a non-zero value for this mutex.
    #[cfg(Py_3_14)]
    pub fn is_locked(&self) -> bool {
        let raw = self.mutex.get();
        // SAFETY: `raw` points at the `crate::ffi::PyMutex` stored inside
        // `self`, which stays alive for the whole call.
        let state = unsafe { crate::ffi::PyMutex_IsLocked(raw) };
        state != 0
    }

    /// Consumes the mutex and returns the protected value.
    ///
    /// # Errors
    ///
    /// Returns `Err(PoisonError)` wrapping the value when the poison flag is set.
    pub fn into_inner(self) -> LockResult<T>
    where
        T: Sized,
    {
        // `PyMutex` has no `Drop` impl in this file, so it can be taken apart
        // field by field.
        let Self { poison, data, .. } = self;
        let value = data.into_inner();
        map_result(poison.borrow(), move |()| value)
    }

    /// Resets the poison flag so later `lock` calls return `Ok` again.
    pub fn clear_poison(&self) {
        self.poison.clear();
    }
}
/// Applies `f` to the payload of a `LockResult`, keeping its poison state.
///
/// `Ok(t)` becomes `Ok(f(t))`. When panics unwind, `Err(e)` becomes a new
/// `PoisonError` wrapping `f` applied to the value carried by `e`. Under any other
/// panic strategy the `Err` arm is treated as unreachable.
#[cfg_attr(not(panic = "unwind"), allow(clippy::unnecessary_wraps))]
fn map_result<T, U, F>(result: LockResult<T>, f: F) -> LockResult<U>
where
    F: FnOnce(T) -> U,
{
    match result {
        Ok(value) => Ok(f(value)),
        #[cfg(panic = "unwind")]
        Err(poisoned) => {
            let mapped = f(poisoned.into_inner());
            Err(PoisonError::new(mapped))
        }
        #[cfg(not(panic = "unwind"))]
        Err(_) => unreachable!(),
    }
}
impl<'mutex, T: ?Sized> PyMutexGuard<'mutex, T> {
fn new(lock: &'mutex PyMutex<T>) -> LockResult<PyMutexGuard<'mutex, T>> {
map_result(lock.poison.guard(), |guard| PyMutexGuard {
inner: lock,
poison: guard,
_phantom: PhantomData,
})
}
}
impl<'a, T: ?Sized> Drop for PyMutexGuard<'a, T> {
fn drop(&mut self) {
unsafe {
self.inner.poison.done(&self.poison);
crate::ffi::PyMutex_Unlock(UnsafeCell::raw_get(&self.inner.mutex))
};
}
}
impl<'a, T> Deref for PyMutexGuard<'a, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard itself is borrowed.
    fn deref(&self) -> &Self::Target {
        let value: *const T = self.inner.data.get();
        // SAFETY: `value` points into the `PyMutex` that this guard borrows,
        // so it is valid for the lifetime of `&self`. The guard stands for the
        // held lock, which is what keeps other threads away from the value.
        unsafe { &*value }
    }
}
impl<'a, T> DerefMut for PyMutexGuard<'a, T> {
    /// Mutably borrows the protected value for as long as the guard is mutably
    /// borrowed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let value: *mut T = self.inner.data.get();
        // SAFETY: `value` points into the `PyMutex` that this guard borrows.
        // `&mut self` makes this the only borrow handed out through this guard,
        // and the guard stands for the held lock.
        unsafe { &mut *value }
    }
}
// Unit tests for `PyMutex`: basic locking across threads, blocking behaviour,
// poison recovery, and the `Send`/`Sync` markers.
#[cfg(test)]
mod tests {
// Every threaded test and its imports are compiled out on wasm32
// (presumably because threads are unavailable in that test setup; TODO confirm).
#[cfg(not(target_arch = "wasm32"))]
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Barrier,
};
use super::*;
#[cfg(not(target_arch = "wasm32"))]
use crate::types::{PyAnyMethods, PyDict, PyDictMethods, PyNone};
#[cfg(not(target_arch = "wasm32"))]
use crate::Py;
#[cfg(not(target_arch = "wasm32"))]
use crate::Python;
// Locks a mutex holding a Python dict from a spawned thread and then from the
// original thread, checking the dict contents and (on 3.14+) `is_locked`.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_pymutex() {
let mutex = Python::attach(|py| -> PyMutex<Py<PyDict>> {
let d = PyDict::new(py);
PyMutex::new(d.unbind())
});
// The final `mutex` binding is only read by the 3.14-gated assertion below.
#[cfg_attr(not(Py_3_14), allow(unused_variables))]
let mutex = Python::attach(|py| {
// The mutex is moved into the worker thread and handed back through `join`.
let mutex = py.detach(|| -> PyMutex<Py<PyDict>> {
std::thread::spawn(|| {
let dict_guard = mutex.lock().unwrap();
Python::attach(|py| {
let dict = dict_guard.bind(py);
dict.set_item(PyNone::get(py), PyNone::get(py)).unwrap();
});
#[cfg(Py_3_14)]
assert!(mutex.is_locked());
drop(dict_guard);
#[cfg(Py_3_14)]
assert!(!mutex.is_locked());
mutex
})
.join()
.unwrap()
});
// Back on the original thread: the entry written by the worker must be visible.
let dict_guard = mutex.lock().unwrap();
#[cfg(Py_3_14)]
assert!(mutex.is_locked());
let d = dict_guard.bind(py);
assert!(d
.get_item(PyNone::get(py))
.unwrap()
.unwrap()
.eq(PyNone::get(py))
.unwrap());
#[cfg(Py_3_14)]
assert!(mutex.is_locked());
drop(dict_guard);
#[cfg(Py_3_14)]
assert!(!mutex.is_locked());
mutex
});
#[cfg(Py_3_14)]
assert!(!mutex.is_locked());
}
// Checks that a second `lock` call does not succeed while another thread holds the
// guard: the second thread asserts `finished` is already set once it gets the lock.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_pymutex_blocks() {
let mutex = PyMutex::new(());
let first_thread_locked_once = AtomicBool::new(false);
let second_thread_locked_once = AtomicBool::new(false);
let finished = AtomicBool::new(false);
// Two participants: the first worker thread and the scope body itself.
let barrier = Barrier::new(2);
std::thread::scope(|s| {
// First worker: takes the lock and keeps it until `finished` is set.
s.spawn(|| {
let guard = mutex.lock();
first_thread_locked_once.store(true, Ordering::SeqCst);
while !finished.load(Ordering::SeqCst) {
if second_thread_locked_once.load(Ordering::SeqCst) {
// Short pause before proceeding; presumably gives the second worker time to
// actually reach its `lock` call (TODO confirm intent).
std::thread::sleep(std::time::Duration::from_millis(10));
barrier.wait();
finished.store(true, Ordering::SeqCst);
}
}
drop(guard);
});
// Second worker: waits until the first one holds the lock, then tries to lock.
s.spawn(|| {
while !first_thread_locked_once.load(Ordering::SeqCst) {
std::hint::spin_loop();
}
second_thread_locked_once.store(true, Ordering::SeqCst);
let guard = mutex.lock();
assert!(finished.load(Ordering::SeqCst));
drop(guard);
});
barrier.wait();
});
}
// Exercises poisoning: a panic while holding the guard makes `lock` and `into_inner`
// return `Err`, the value stays recoverable, and `clear_poison` resets the state.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_recover_poison() {
let mutex = Python::attach(|py| -> PyMutex<Py<PyDict>> {
let d = PyDict::new(py);
d.set_item("hello", "world").unwrap();
PyMutex::new(d.unbind())
});
let lock = Arc::new(mutex);
let lock2 = Arc::clone(&lock);
// NOTE(review): `thread` resolves through `use super::*`; the parent only imports
// `std::thread` under `cfg(panic = "unwind")`.
let _ = thread::spawn(move || {
let _guard = lock2.lock().unwrap();
panic!();
})
.join();
// The poisoned lock still yields a usable guard via `PoisonError::into_inner`.
let guard = match lock.lock() {
Ok(_) => {
unreachable!();
}
Err(poisoned) => poisoned.into_inner(),
};
Python::attach(|py| {
assert!(
(*guard)
.bind(py)
.get_item("hello")
.unwrap()
.unwrap()
.extract::<&str>()
.unwrap()
== "world"
);
});
// `into_inner` on a mutex that was never poisoned returns `Ok`.
let mutex = PyMutex::new(0);
assert_eq!(mutex.into_inner().unwrap(), 0);
// `into_inner` on a poisoned mutex returns `Err` that still carries the value.
let mutex = PyMutex::new(0);
let _ = std::thread::scope(|s| {
s.spawn(|| {
let _guard = mutex.lock().unwrap();
panic!();
})
.join()
});
match mutex.into_inner() {
Ok(_) => {
unreachable!()
}
Err(e) => {
assert!(e.into_inner() == 0)
}
}
// After `clear_poison`, locking a previously poisoned mutex succeeds again.
let mutex = PyMutex::new(0);
let _ = std::thread::scope(|s| {
s.spawn(|| {
let _guard = mutex.lock().unwrap();
panic!();
})
.join()
});
mutex.clear_poison();
assert_eq!(*mutex.lock().unwrap(), 0);
}
// Pins the auto-trait surface: the guard is `Sync` but not `Send`; the mutex is both
// (checked here for `T = i32`).
#[test]
fn test_send_not_send() {
use crate::impl_::pyclass::{value_of, IsSend, IsSync};
assert!(!value_of!(IsSend, PyMutexGuard<'_, i32>));
assert!(value_of!(IsSync, PyMutexGuard<'_, i32>));
assert!(value_of!(IsSend, PyMutex<i32>));
assert!(value_of!(IsSync, PyMutex<i32>));
}
}