#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
use crate::types::PyMutex;
#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
use crate::Python;
use crate::{types::PyAny, Bound};
#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
use std::cell::UnsafeCell;
/// Owns a single-object `PyCriticalSection` and ends it when dropped.
///
/// Ending the section from `Drop` means it is released on every exit path from the scope
/// that owns the guard, including unwinding out of a panicking closure.
#[cfg(Py_GIL_DISABLED)]
struct CSGuard(crate::ffi::PyCriticalSection);

#[cfg(Py_GIL_DISABLED)]
impl Drop for CSGuard {
    fn drop(&mut self) {
        let section = &mut self.0;
        // SAFETY: every `CSGuard` in this module is handed to a `PyCriticalSection_Begin*`
        // call right after construction and is not moved afterwards, so `section` is the
        // critical section this guard began.
        unsafe { crate::ffi::PyCriticalSection_End(section) };
    }
}
/// Owns a two-lock `PyCriticalSection2` and ends it when dropped.
///
/// Ending the section from `Drop` means both locks are released on every exit path from the
/// scope that owns the guard, including unwinding out of a panicking closure.
#[cfg(Py_GIL_DISABLED)]
struct CS2Guard(crate::ffi::PyCriticalSection2);

#[cfg(Py_GIL_DISABLED)]
impl Drop for CS2Guard {
    fn drop(&mut self) {
        let section = &mut self.0;
        // SAFETY: every `CS2Guard` in this module is handed to a `PyCriticalSection2_Begin*`
        // call right after construction and is not moved afterwards, so `section` is the
        // critical section this guard began.
        unsafe { crate::ffi::PyCriticalSection2_End(section) };
    }
}
/// Handle to the data guarded by a [`PyMutex`], passed to the closure given to
/// [`with_critical_section_mutex`] / [`with_critical_section_mutex2`].
///
/// The handle only wraps a borrow of the mutex's `UnsafeCell`. Access goes through the
/// `unsafe` accessors below, because a critical section does not by itself rule out every
/// form of aliasing (see their `# Safety` sections).
#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
pub struct EnteredCriticalSection<'a, T>(&'a UnsafeCell<T>);

#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
impl<T> EnteredCriticalSection<'_, T> {
    /// Returns a mutable reference to the guarded data.
    ///
    /// # Safety
    ///
    /// No other reference (shared or mutable) to the same data may be alive while the
    /// returned reference is in use. In particular:
    ///
    /// - [`with_critical_section_mutex2`] accepts the same mutex for both arguments. In that
    ///   case both handles point at the same data, so references obtained from them must not
    ///   overlap.
    /// - On builds with the GIL no lock is taken at all, so exclusivity relies on the GIL.
    /// - On free-threaded builds, CPython may suspend a critical section (temporarily
    ///   releasing its lock) when the thread blocks or detaches from the interpreter.
    /// - Because of the two points above, the reference must not be held across calls that
    ///   may release the GIL or suspend the section, for example running arbitrary Python
    ///   code.
    pub unsafe fn get_mut(&mut self) -> &mut T {
        // SAFETY: the caller upholds the exclusivity contract documented above.
        unsafe { &mut *(self.0.get()) }
    }

    /// Returns a shared reference to the guarded data.
    ///
    /// # Safety
    ///
    /// No mutable reference to the same data may be alive while the returned reference is in
    /// use. This includes one obtained through another handle to the same mutex, e.g. from
    /// [`with_critical_section_mutex2`] called with one mutex twice. As with
    /// [`get_mut`](Self::get_mut), the reference must not be held across calls that may
    /// release the GIL or suspend the critical section.
    pub unsafe fn get(&self) -> &T {
        // SAFETY: the caller upholds the aliasing contract documented above.
        unsafe { &*(self.0.get()) }
    }
}
/// Runs `f` while holding a critical section on `object`, returning whatever `f` returns.
///
/// On free-threaded builds (`Py_GIL_DISABLED`):
///
/// - a per-object critical section is begun with `PyCriticalSection_Begin` before `f` runs;
/// - it is ended by a drop guard afterwards, so it is also ended if `f` panics.
///
/// Per the CPython critical-section docs, the section may be temporarily suspended while the
/// thread blocks or detaches from the interpreter.
///
/// On builds with the GIL, `f` is called directly and `object` is unused.
#[cfg_attr(not(Py_GIL_DISABLED), allow(unused_variables))]
pub fn with_critical_section<F, R>(object: &Bound<'_, PyAny>, f: F) -> R
where
    F: FnOnce() -> R,
{
    #[cfg(not(Py_GIL_DISABLED))]
    {
        f()
    }
    #[cfg(Py_GIL_DISABLED)]
    {
        // Zeroed placeholder storage. `PyCriticalSection_Begin` fills it in; the C struct is
        // presumed valid when zeroed (pointer/integer fields — see pyo3-ffi).
        let blank: crate::ffi::PyCriticalSection = unsafe { std::mem::zeroed() };
        let mut section = CSGuard(blank);
        // SAFETY: `object.as_ptr()` is a live object pointer for the whole borrow. `section`
        // is not moved after this call and is dropped (ending the section) at the end of this
        // block.
        unsafe { crate::ffi::PyCriticalSection_Begin(&mut section.0, object.as_ptr()) };
        f()
    }
}
/// Runs `f` while holding a critical section on both `a` and `b`, returning whatever `f`
/// returns.
///
/// On free-threaded builds (`Py_GIL_DISABLED`), a two-object critical section is begun with
/// `PyCriticalSection2_Begin` before `f` runs. It is ended by a drop guard afterwards, so it
/// is also ended if `f` panics. The tests in this file also pass the same object as both `a`
/// and `b`.
///
/// On builds with the GIL, `f` is called directly and `a`/`b` are unused.
#[cfg_attr(not(Py_GIL_DISABLED), allow(unused_variables))]
pub fn with_critical_section2<F, R>(a: &Bound<'_, PyAny>, b: &Bound<'_, PyAny>, f: F) -> R
where
    F: FnOnce() -> R,
{
    #[cfg(not(Py_GIL_DISABLED))]
    {
        f()
    }
    #[cfg(Py_GIL_DISABLED)]
    {
        // Zeroed placeholder storage. `PyCriticalSection2_Begin` fills it in; the C struct is
        // presumed valid when zeroed (pointer/integer fields — see pyo3-ffi).
        let blank: crate::ffi::PyCriticalSection2 = unsafe { std::mem::zeroed() };
        let mut section = CS2Guard(blank);
        let (first, second) = (a.as_ptr(), b.as_ptr());
        // SAFETY: both pointers come from live `Bound`s borrowed for this whole call. `section`
        // is not moved after this call and is dropped (ending the section) at the end of this
        // block.
        unsafe { crate::ffi::PyCriticalSection2_Begin(&mut section.0, first, second) };
        f()
    }
}
/// Runs `f` with access to the data of `mutex` while holding a critical section on it.
///
/// On free-threaded builds (`Py_GIL_DISABLED`):
///
/// - the mutex's lock is taken with `PyCriticalSection_BeginMutex` before `f` runs;
/// - it is released by a drop guard afterwards, including if `f` panics.
///
/// On builds with the GIL, no lock is taken. Access is then only serialized by the GIL, which
/// the `_py` token shows is held.
///
/// The data is exposed through an [`EnteredCriticalSection`]; see its accessors for the
/// aliasing rules callers must follow.
#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
pub fn with_critical_section_mutex<F, R, T>(_py: Python<'_>, mutex: &PyMutex<T>, f: F) -> R
where
    F: for<'s> FnOnce(EnteredCriticalSection<'s, T>) -> R,
{
    #[cfg(Py_GIL_DISABLED)]
    {
        let mut guard = CSGuard(unsafe { std::mem::zeroed() });
        // SAFETY: `mutex.mutex.get()` points at the FFI lock owned by `mutex`, which outlives
        // this call. `guard` is not moved until it is dropped at the end of this block.
        //
        // The raw pointer is passed through as-is rather than reborrowed as `&mut`: other
        // threads may be operating on the same lock concurrently, so a unique reference to it
        // would claim an exclusivity that does not hold.
        unsafe { crate::ffi::PyCriticalSection_BeginMutex(&mut guard.0, mutex.mutex.get()) };
        f(EnteredCriticalSection(&mutex.data))
    }
    #[cfg(not(Py_GIL_DISABLED))]
    {
        f(EnteredCriticalSection(&mutex.data))
    }
}
/// Runs `f` with access to the data of both `m1` and `m2` while holding a critical section on
/// both mutexes.
///
/// On free-threaded builds (`Py_GIL_DISABLED`):
///
/// - both locks are taken with `PyCriticalSection2_BeginMutex` before `f` runs;
/// - they are released by a drop guard afterwards, including if `f` panics.
///
/// On builds with the GIL, no lock is taken. Access is then only serialized by the GIL, which
/// the `_py` token shows is held.
///
/// The same mutex may be passed as both `m1` and `m2`; the tests in this file exercise that
/// case. The two [`EnteredCriticalSection`] handles then alias the same data, so references
/// obtained from them must not overlap.
#[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
pub fn with_critical_section_mutex2<F, R, T1, T2>(
    _py: Python<'_>,
    m1: &PyMutex<T1>,
    m2: &PyMutex<T2>,
    f: F,
) -> R
where
    F: for<'s> FnOnce(EnteredCriticalSection<'s, T1>, EnteredCriticalSection<'s, T2>) -> R,
{
    #[cfg(Py_GIL_DISABLED)]
    {
        let mut guard = CS2Guard(unsafe { std::mem::zeroed() });
        // SAFETY: both pointers refer to FFI locks owned by `m1`/`m2`, which outlive this
        // call. `guard` is not moved until it is dropped at the end of this block.
        //
        // The raw pointers are passed through as-is rather than reborrowed as `&mut`: the
        // locks may be used by other threads concurrently, and `m1` and `m2` may even be the
        // same mutex, so unique references would assert an exclusivity that does not hold.
        unsafe {
            crate::ffi::PyCriticalSection2_BeginMutex(
                &mut guard.0,
                m1.mutex.get(),
                m2.mutex.get(),
            )
        };
        f(
            EnteredCriticalSection(&m1.data),
            EnteredCriticalSection(&m2.data),
        )
    }
    #[cfg(not(Py_GIL_DISABLED))]
    {
        f(
            EnteredCriticalSection(&m1.data),
            EnteredCriticalSection(&m2.data),
        )
    }
}
// Multi-threaded tests for the critical-section helpers. The module is skipped on wasm32
// (presumably because native threads are unavailable there — TODO confirm).
//
// Most tests share one pattern:
// - a "writer" thread enters a critical section, meets the other threads at a `Barrier`,
//   sleeps briefly, then writes;
// - the "reader" threads pass the barrier, then attach and enter a critical section on the
//   same object/mutex.
// The writer is still inside its section during the sleep (on GIL builds it is also still
// attached, i.e. holding the GIL). So the readers can only get in after the write, and must
// observe it.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod tests {
    #[cfg(feature = "macros")]
    use super::{with_critical_section, with_critical_section2};
    #[cfg(all(not(Py_LIMITED_API), Py_3_14))]
    use super::{with_critical_section_mutex, with_critical_section_mutex2};
    #[cfg(all(not(Py_LIMITED_API), Py_3_14))]
    use crate::types::PyMutex;
    #[cfg(feature = "macros")]
    use std::sync::atomic::{AtomicBool, Ordering};
    #[cfg(any(feature = "macros", all(not(Py_LIMITED_API), Py_3_14)))]
    use std::sync::Barrier;
    #[cfg(feature = "macros")]
    use crate::Py;
    #[cfg(any(feature = "macros", all(not(Py_LIMITED_API), Py_3_14)))]
    use crate::Python;

    // Python-visible wrapper around a Vec, used by the two-container test.
    #[cfg(feature = "macros")]
    #[crate::pyclass(crate = "crate")]
    struct VecWrapper(Vec<isize>);

    // Python-visible wrapper around a flag the writer threads set.
    #[cfg(feature = "macros")]
    #[crate::pyclass(crate = "crate")]
    struct BoolWrapper(AtomicBool);

    // One writer and one reader on the same object: the reader must see the flag set.
    #[cfg(feature = "macros")]
    #[test]
    fn test_critical_section() {
        let barrier = Barrier::new(2);
        let bool_wrapper = Python::attach(|py| -> Py<BoolWrapper> {
            Py::new(py, BoolWrapper(AtomicBool::new(false))).unwrap()
        });
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    let b = bool_wrapper.bind(py);
                    with_critical_section(b, || {
                        barrier.wait();
                        // Keep the section held briefly so the reader has to wait for it.
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        b.borrow().0.store(true, Ordering::Release);
                    })
                });
            });
            s.spawn(|| {
                // Wait on the barrier before attaching, so the writer is already inside its
                // section when this thread tries to enter one.
                barrier.wait();
                Python::attach(|py| {
                    let b = bool_wrapper.bind(py);
                    with_critical_section(b, || {
                        assert!(b.borrow().0.load(Ordering::Acquire));
                    });
                });
            });
        });
    }

    // Same as `test_critical_section`, but guarding a `PyMutex<bool>` instead of a pyclass.
    #[cfg(all(not(Py_LIMITED_API), Py_3_14))]
    #[test]
    fn test_critical_section_mutex() {
        let barrier = Barrier::new(2);
        let mutex = PyMutex::new(false);
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    with_critical_section_mutex(py, &mutex, |mut b| {
                        barrier.wait();
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *(unsafe { b.get_mut() }) = true;
                    });
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    with_critical_section_mutex(py, &mutex, |b| {
                        assert!(unsafe { *b.get() });
                    });
                });
            });
        });
    }

    // One writer holds a section on both objects; two readers each lock one of them with a
    // single-object section and must see their flag set.
    #[cfg(feature = "macros")]
    #[test]
    fn test_critical_section2() {
        let barrier = Barrier::new(3);
        let (bool_wrapper1, bool_wrapper2) = Python::attach(|py| {
            (
                Py::new(py, BoolWrapper(AtomicBool::new(false))).unwrap(),
                Py::new(py, BoolWrapper(AtomicBool::new(false))).unwrap(),
            )
        });
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    let b1 = bool_wrapper1.bind(py);
                    let b2 = bool_wrapper2.bind(py);
                    with_critical_section2(b1, b2, || {
                        barrier.wait();
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        b1.borrow().0.store(true, Ordering::Release);
                        b2.borrow().0.store(true, Ordering::Release);
                    })
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    let b1 = bool_wrapper1.bind(py);
                    with_critical_section(b1, || {
                        assert!(b1.borrow().0.load(Ordering::Acquire));
                    });
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    let b2 = bool_wrapper2.bind(py);
                    with_critical_section(b2, || {
                        assert!(b2.borrow().0.load(Ordering::Acquire));
                    });
                });
            });
        });
    }

    // Two-mutex variant of `test_critical_section_mutex`: the reader locks both mutexes too.
    #[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
    #[test]
    fn test_critical_section_mutex2() {
        let barrier = Barrier::new(2);
        let m1 = PyMutex::new(false);
        let m2 = PyMutex::new(false);
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    with_critical_section_mutex2(py, &m1, &m2, |mut b1, mut b2| {
                        barrier.wait();
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        unsafe { (*b1.get_mut()) = true };
                        unsafe { (*b2.get_mut()) = true };
                    });
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    with_critical_section_mutex2(py, &m1, &m2, |b1, b2| {
                        assert!(unsafe { *b1.get() });
                        assert!(unsafe { *b2.get() });
                    });
                });
            });
        });
    }

    // Passing the same object twice to `with_critical_section2` must not self-deadlock, and
    // must still exclude other threads.
    #[cfg(feature = "macros")]
    #[test]
    fn test_critical_section2_same_object_no_deadlock() {
        let barrier = Barrier::new(2);
        let bool_wrapper = Python::attach(|py| -> Py<BoolWrapper> {
            Py::new(py, BoolWrapper(AtomicBool::new(false))).unwrap()
        });
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    let b = bool_wrapper.bind(py);
                    with_critical_section2(b, b, || {
                        barrier.wait();
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        b.borrow().0.store(true, Ordering::Release);
                    })
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    let b = bool_wrapper.bind(py);
                    with_critical_section(b, || {
                        assert!(b.borrow().0.load(Ordering::Acquire));
                    });
                });
            });
        });
    }

    // Passing the same mutex twice to `with_critical_section_mutex2` must not self-deadlock.
    // Both handles alias the same data, so the write through `b1` is visible through `b2`.
    // The `&mut` from `get_mut` ends before `get` is called, so the references never overlap.
    #[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
    #[test]
    fn test_critical_section_mutex2_same_object_no_deadlock() {
        let barrier = Barrier::new(2);
        let m = PyMutex::new(false);
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    with_critical_section_mutex2(py, &m, &m, |mut b1, b2| {
                        barrier.wait();
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        unsafe { (*b1.get_mut()) = true };
                        assert!(unsafe { *b2.get() });
                    });
                });
            });
            s.spawn(|| {
                barrier.wait();
                Python::attach(|py| {
                    with_critical_section_mutex(py, &m, |b| {
                        assert!(unsafe { *b.get() });
                    });
                });
            });
        });
    }

    // Two threads lock the same pair of objects in opposite roles (each appends one vec to the
    // other). The sections must serialize them, so the final state matches one of the two
    // possible orderings:
    // - thread 1 first → expected2;
    // - thread 2 first → expected1.
    #[cfg(feature = "macros")]
    #[test]
    fn test_critical_section2_two_containers() {
        let (vec1, vec2) = Python::attach(|py| {
            (
                Py::new(py, VecWrapper(vec![1, 2, 3])).unwrap(),
                Py::new(py, VecWrapper(vec![4, 5])).unwrap(),
            )
        });
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    let v1 = vec1.bind(py);
                    let v2 = vec2.bind(py);
                    with_critical_section2(v1, v2, || {
                        v2.borrow_mut().0.extend(v1.borrow().0.iter());
                    })
                });
            });
            s.spawn(|| {
                Python::attach(|py| {
                    let v1 = vec1.bind(py);
                    let v2 = vec2.bind(py);
                    with_critical_section2(v1, v2, || {
                        v1.borrow_mut().0.extend(v2.borrow().0.iter());
                    })
                });
            });
        });
        Python::attach(|py| {
            let v1 = vec1.bind(py);
            let v2 = vec2.bind(py);
            let expected1_vec1 = vec![1, 2, 3, 4, 5];
            let expected1_vec2 = vec![4, 5, 1, 2, 3, 4, 5];
            let expected2_vec1 = vec![1, 2, 3, 4, 5, 1, 2, 3];
            let expected2_vec2 = vec![4, 5, 1, 2, 3];
            assert!(
                (v1.borrow().0.eq(&expected1_vec1) && v2.borrow().0.eq(&expected1_vec2))
                    || (v1.borrow().0.eq(&expected2_vec1) && v2.borrow().0.eq(&expected2_vec2))
            );
        });
    }

    // `PyMutex` version of the two-container test.
    //
    // The main thread first locks both mutexes directly. On free-threaded builds this holds
    // the spawned threads' critical sections off, so the main thread can check, while attached,
    // that the data is still untouched. Only then does it release the locks.
    //
    // On GIL builds the critical sections take no lock, so that check is compiled only for
    // `Py_GIL_DISABLED`.
    #[cfg(all(Py_3_14, not(Py_LIMITED_API)))]
    #[test]
    fn test_critical_section_mutex2_two_containers() {
        let (m1, m2) = (PyMutex::new(vec![1, 2, 3]), PyMutex::new(vec![4, 5]));
        let (m1_guard, m2_guard) = (m1.lock().unwrap(), m2.lock().unwrap());
        std::thread::scope(|s| {
            s.spawn(|| {
                Python::attach(|py| {
                    with_critical_section_mutex2(py, &m1, &m2, |mut v1, v2| {
                        let vec1 = unsafe { v1.get_mut() };
                        let vec2 = unsafe { v2.get() };
                        vec1.extend(vec2.iter());
                    })
                });
            });
            s.spawn(|| {
                Python::attach(|py| {
                    with_critical_section_mutex2(py, &m1, &m2, |v1, mut v2| {
                        let vec1 = unsafe { v1.get() };
                        let vec2 = unsafe { v2.get_mut() };
                        vec2.extend(vec1.iter());
                    })
                });
            });
            Python::attach(|_| {
                #[cfg(Py_GIL_DISABLED)]
                {
                    assert_eq!(&*m1_guard, &[1, 2, 3]);
                    assert_eq!(&*m2_guard, &[4, 5]);
                }
            });
            // Release the locks so the spawned threads can proceed; the scope then joins them.
            drop(m1_guard);
            drop(m2_guard);
        });
        let expected1_vec1 = vec![1, 2, 3, 4, 5];
        let expected1_vec2 = vec![4, 5, 1, 2, 3, 4, 5];
        let expected2_vec1 = vec![1, 2, 3, 4, 5, 1, 2, 3];
        let expected2_vec2 = vec![4, 5, 1, 2, 3];
        let v1 = m1.lock().unwrap();
        let v2 = m2.lock().unwrap();
        assert!(
            (&*v1, &*v2) == (&expected1_vec1, &expected1_vec2)
                || (&*v1, &*v2) == (&expected2_vec1, &expected2_vec2)
        );
    }
}