use crate::Python;
use crate::{ffi, AsPyPointer, PyAny};
use crate::{pyobject_native_type_core, PyErr, PyResult};
use std::ffi::{c_void, CStr, CString};
use std::os::raw::c_int;
use std::thread::{self, ThreadId};
/// Represents a Python `PyCapsule` object: an opaque pointer stored together with an
/// (optionally null) name and an optional context pointer.
// `repr(transparent)` gives `PyCapsule` the same layout as the `PyAny` it wraps.
#[repr(transparent)]
pub struct PyCapsule(PyAny);
// Registers `PyCapsule` as a native Python type backed by `ffi::PyCapsule_Type`, passing
// `ffi::PyCapsule_CheckExact` as its check function.
pyobject_native_type_core!(PyCapsule, ffi::PyCapsule_Type, #checkfunction=ffi::PyCapsule_CheckExact);
impl PyCapsule {
    /// Creates a new capsule holding `value`, labelled with a copy of `name`.
    ///
    /// When the capsule is destroyed the value is dropped; no custom destructor runs.
    ///
    /// # Errors
    /// Returns the Python error raised if the capsule object could not be created.
    pub fn new<'py, T: 'static + Send + AssertNotZeroSized>(
        py: Python<'py>,
        value: T,
        name: &CStr,
    ) -> PyResult<&'py Self> {
        Self::new_with_destructor(py, value, name, |_, _| {})
    }

    /// Creates a new capsule holding `value`, labelled with a copy of `name`, and registers
    /// `destructor`.
    ///
    /// When the capsule is destroyed, `destructor` is called with the stored value and the
    /// capsule's context pointer (see [`PyCapsule::set_context`]). This only happens if the
    /// capsule is destroyed on the thread that created it. Otherwise a `RuntimeWarning` is
    /// issued and the destructor is skipped.
    ///
    /// # Errors
    /// Returns the Python error raised if the capsule object could not be created. In that
    /// case `value` and `destructor` are dropped, and `destructor` is never called.
    pub fn new_with_destructor<
        'py,
        T: 'static + Send + AssertNotZeroSized,
        F: FnOnce(T, *mut c_void),
    >(
        py: Python<'py>,
        value: T,
        name: &CStr,
        destructor: F,
    ) -> PyResult<&'py Self> {
        AssertNotZeroSized::assert_not_zero_sized(&value);
        // The capsule keeps only `name_ptr`, so an owned copy of the name is stored in the
        // contents. Moving the `CString` into the box does not move its heap buffer, so
        // `name_ptr` stays valid for as long as the contents live.
        let name = name.to_owned();
        let name_ptr = name.as_ptr();
        let contents = Box::into_raw(Box::new(CapsuleContents {
            value,
            destructor,
            thread_id: thread::current().id(),
            name,
        }));
        // SAFETY: `contents` is a valid non-null pointer from `Box::into_raw`, and
        // `capsule_destructor::<T, F>` reclaims it with exactly this type.
        let cap_ptr = unsafe {
            ffi::PyCapsule_New(
                contents as *mut c_void,
                name_ptr,
                Some(capsule_destructor::<T, F>),
            )
        };
        if cap_ptr.is_null() {
            // The capsule was never created, so its destructor will never reclaim `contents`.
            // Take the error first so that dropping `T`/`F` runs without a pending exception.
            let err = PyErr::fetch(py);
            // SAFETY: ownership of `contents` was never handed to Python.
            drop(unsafe { Box::from_raw(contents) });
            return Err(err);
        }
        // SAFETY: `cap_ptr` is a non-null, newly created (owned) reference.
        unsafe { py.from_owned_ptr_or_err(cap_ptr) }
    }

    /// Imports the capsule identified by `name` via `PyCapsule_Import` and returns a
    /// reference to the value it points to. The name is a module path followed by the
    /// attribute that holds the capsule, e.g. `"builtins.capsule"`.
    ///
    /// # Errors
    /// Returns the Python error raised if the import fails.
    ///
    /// # Safety
    /// The capsule's pointer must point to a valid `T`. This holds, for example, for a capsule
    /// created by [`PyCapsule::new`] with a value of type `T`. The returned reference is not
    /// tied to the capsule's lifetime, so the caller must ensure the capsule outlives its use.
    pub unsafe fn import<'py, T>(py: Python<'py>, name: &CStr) -> PyResult<&'py T> {
        let ptr = ffi::PyCapsule_Import(name.as_ptr(), false as c_int);
        if ptr.is_null() {
            Err(PyErr::fetch(py))
        } else {
            Ok(&*(ptr as *const T))
        }
    }

    /// Stores an arbitrary context pointer in the capsule.
    ///
    /// The pointer is only stored, never dereferenced. It is passed to the destructor given to
    /// [`PyCapsule::new_with_destructor`].
    ///
    /// # Errors
    /// Returns the Python error raised if `PyCapsule_SetContext` fails.
    // The pointer argument is never dereferenced here, only handed to CPython for storage.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub fn set_context(&self, py: Python<'_>, context: *mut c_void) -> PyResult<()> {
        // SAFETY: `self` is a valid capsule object.
        let status = unsafe { ffi::PyCapsule_SetContext(self.as_ptr(), context) };
        if status != 0 {
            Err(PyErr::fetch(py))
        } else {
            Ok(())
        }
    }

    /// Returns the context pointer stored in the capsule, or null if none was set.
    ///
    /// # Errors
    /// Returns the Python error raised if `PyCapsule_GetContext` fails.
    pub fn get_context(&self, py: Python<'_>) -> PyResult<*mut c_void> {
        // SAFETY: `self` is a valid capsule object.
        let ctx = unsafe { ffi::PyCapsule_GetContext(self.as_ptr()) };
        // A null context is a legitimate value; it only signals failure when an error is set.
        if ctx.is_null() && PyErr::occurred(py) {
            Err(PyErr::fetch(py))
        } else {
            Ok(ctx)
        }
    }

    /// Returns a reference to the value stored in the capsule.
    ///
    /// # Safety
    /// The capsule's pointer must be non-null and point to a valid `T` that outlives the
    /// returned reference. For example, the capsule was created by [`PyCapsule::new`] with a
    /// value of type `T`.
    pub unsafe fn reference<T>(&self) -> &T {
        &*(self.pointer() as *const T)
    }

    /// Returns the raw pointer stored in the capsule.
    ///
    /// Returns null (with no Python error left pending) if the pointer cannot be retrieved.
    pub fn pointer(&self) -> *mut c_void {
        // SAFETY: `self` is a valid capsule, and `name_ptr()` is the capsule's own name
        // pointer, so the name check in `PyCapsule_GetPointer` matches even when it is null.
        let ptr = unsafe { ffi::PyCapsule_GetPointer(self.as_ptr(), self.name_ptr()) };
        if ptr.is_null() {
            // Callers only see the null pointer; do not leave the exception pending.
            unsafe { ffi::PyErr_Clear() };
        }
        ptr
    }

    /// Returns whether the capsule is valid according to `PyCapsule_IsValid`, checked against
    /// the capsule's own name.
    pub fn is_valid(&self) -> bool {
        // SAFETY: `self` is a valid capsule object; `name_ptr()` may be null for an unnamed
        // capsule, which `PyCapsule_IsValid` matches against a null stored name.
        let valid = unsafe { ffi::PyCapsule_IsValid(self.as_ptr(), self.name_ptr()) };
        valid != 0
    }

    /// Returns the capsule's name, or an empty string if the capsule has no name.
    pub fn name(&self) -> &CStr {
        let ptr = self.name_ptr();
        if ptr.is_null() {
            <&CStr>::default()
        } else {
            // SAFETY: a non-null capsule name is a nul-terminated string owned by the capsule.
            unsafe { CStr::from_ptr(ptr) }
        }
    }

    /// Returns the raw name pointer stored in the capsule (null for an unnamed capsule),
    /// clearing any Python error state when the result is null.
    fn name_ptr(&self) -> *const std::os::raw::c_char {
        // SAFETY: `self` is a valid capsule object.
        let ptr = unsafe { ffi::PyCapsule_GetName(self.as_ptr()) };
        if ptr.is_null() {
            unsafe { ffi::PyErr_Clear() };
        }
        ptr
    }
}
/// Heap-allocated payload behind every capsule created by `PyCapsule::new_with_destructor`.
///
/// `#[repr(C)]` places `value` first, so a pointer to this struct is also a valid pointer to
/// `value`. `PyCapsule::reference` and `PyCapsule::import` rely on this by casting the capsule
/// pointer directly to `*const T`. Do not reorder the fields.
#[repr(C)]
struct CapsuleContents<T: 'static + Send, D: FnOnce(T, *mut c_void)> {
    /// The user-supplied value.
    value: T,
    /// Called with `value` and the capsule context by `capsule_destructor`.
    destructor: D,
    /// Thread that created the capsule; `capsule_destructor` only runs `destructor` there.
    thread_id: ThreadId,
    /// Owned copy of the capsule name, kept alive because the capsule refers to its buffer.
    name: CString,
}
/// C destructor registered by `PyCapsule::new_with_destructor`; reclaims the boxed
/// `CapsuleContents<T, F>` when the capsule is deallocated.
///
/// If the capsule is destroyed on a thread other than the one that created it, the user
/// destructor `F` is not called. `F` is not required to be `Send`, so it is also not dropped:
/// it is leaked. A `RuntimeWarning` is issued instead. The stored value (which is `Send`) is
/// dropped in either case.
unsafe extern "C" fn capsule_destructor<T: 'static + Send, F: FnOnce(T, *mut c_void)>(
    capsule: *mut ffi::PyObject,
) {
    let ptr = ffi::PyCapsule_GetPointer(capsule, ffi::PyCapsule_GetName(capsule));
    let ctx = ffi::PyCapsule_GetContext(capsule);
    let CapsuleContents {
        value,
        destructor,
        thread_id,
        ..
    } = *Box::from_raw(ptr as *mut CapsuleContents<T, F>);

    if thread_id != thread::current().id() {
        // `F` may hold non-`Send` state, so neither calling nor dropping it is sound here.
        std::mem::forget(destructor);
        let status = ffi::PyErr_WarnEx(
            ffi::PyExc_RuntimeWarning,
            b"capsule destructor called in thread other than the one the capsule was created in, skipping the destructor\0".as_ptr().cast(),
            1,
        );
        if status < 0 {
            // The warning was turned into an exception, which a destructor cannot propagate.
            // `PyErr_WriteUnraisable` borrows its argument, so pass `None` without a new ref.
            ffi::PyErr_WriteUnraisable(ffi::Py_None());
        }
        return;
    }
    destructor(value, ctx)
}
/// Compile-time guard that rejects zero-sized capsule values.
///
/// For a zero-sized `Self`, `_CONDITION` is `1`. Indexing the one-element array in `_CHECK`
/// is then out of bounds, so evaluating `_CHECK` for that type fails and the build is
/// rejected. `assert_not_zero_sized` references `_CHECK` to force that evaluation for each
/// type it is called with.
#[doc(hidden)]
pub trait AssertNotZeroSized: Sized {
    /// `1` if `Self` is zero-sized, `0` otherwise.
    const _CONDITION: usize = (std::mem::size_of::<Self>() == 0) as usize;
    /// Fails const evaluation (with this message in the array) when `_CONDITION` is `1`.
    const _CHECK: &'static str =
        ["PyCapsule value type T must not be zero-sized!"][Self::_CONDITION];
    /// Forces evaluation of `_CHECK` for `Self`; does nothing at runtime.
    // The bare path statement below exists only to reference `_CHECK`, hence the allow.
    #[allow(path_statements, clippy::no_effect)]
    fn assert_not_zero_sized(&self) {
        <Self as AssertNotZeroSized>::_CHECK;
    }
}
// Blanket impl: every type gets the check; only zero-sized ones fail to compile when used.
impl<T> AssertNotZeroSized for T {}
#[cfg(test)]
mod tests {
    use crate::prelude::PyModule;
    use crate::{types::PyCapsule, Py, PyResult, Python};
    use std::ffi::{c_void, CString};
    use std::sync::mpsc::{channel, Sender};

    #[test]
    fn test_pycapsule_struct() -> PyResult<()> {
        #[repr(C)]
        struct Foo {
            pub val: u32,
        }

        impl Foo {
            fn get_val(&self) -> u32 {
                self.val
            }
        }

        Python::with_gil(|py| -> PyResult<()> {
            let foo = Foo { val: 123 };
            let name = CString::new("foo").unwrap();

            let cap = PyCapsule::new(py, foo, &name)?;
            assert!(cap.is_valid());

            let foo_capi = unsafe { cap.reference::<Foo>() };
            assert_eq!(foo_capi.val, 123);
            assert_eq!(foo_capi.get_val(), 123);
            assert_eq!(cap.name(), name.as_ref());
            Ok(())
        })
    }

    #[test]
    fn test_pycapsule_func() {
        fn foo(x: u32) -> u32 {
            x
        }

        let cap: Py<PyCapsule> = Python::with_gil(|py| {
            let name = CString::new("foo").unwrap();
            let cap = PyCapsule::new(py, foo as fn(u32) -> u32, &name).unwrap();
            cap.into()
        });

        // `move` drops `cap` while the GIL is held on this thread, so the capsule is released
        // immediately, not deferred to whichever thread next acquires the GIL.
        Python::with_gil(move |py| {
            let f = unsafe { cap.as_ref(py).reference::<fn(u32) -> u32>() };
            assert_eq!(f(123), 123);
        });
    }

    #[test]
    fn test_pycapsule_context() -> PyResult<()> {
        Python::with_gil(|py| {
            let name = CString::new("foo").unwrap();
            let cap = PyCapsule::new(py, 0, &name)?;

            let c = cap.get_context(py)?;
            assert!(c.is_null());

            let ctx = Box::new(123_u32);
            cap.set_context(py, Box::into_raw(ctx) as _)?;

            let ctx_ptr: *mut c_void = cap.get_context(py)?;
            let ctx = unsafe { *Box::from_raw(ctx_ptr as *mut u32) };
            assert_eq!(ctx, 123);
            Ok(())
        })
    }

    #[test]
    fn test_pycapsule_import() -> PyResult<()> {
        #[repr(C)]
        struct Foo {
            pub val: u32,
        }

        Python::with_gil(|py| -> PyResult<()> {
            let foo = Foo { val: 123 };
            let name = CString::new("builtins.capsule").unwrap();

            let capsule = PyCapsule::new(py, foo, &name)?;

            let module = PyModule::import(py, "builtins")?;
            module.add("capsule", capsule)?;

            // An unknown attribute must fail rather than yield a dangling reference.
            let wrong_name = CString::new("builtins.non_existant").unwrap();
            let result: PyResult<&Foo> = unsafe { PyCapsule::import(py, wrong_name.as_ref()) };
            assert!(result.is_err());

            let cap: &Foo = unsafe { PyCapsule::import(py, name.as_ref())? };
            assert_eq!(cap.val, 123);
            Ok(())
        })
    }

    #[test]
    fn test_vec_storage() {
        let cap: Py<PyCapsule> = Python::with_gil(|py| {
            let name = CString::new("foo").unwrap();

            let stuff: Vec<u8> = vec![1, 2, 3, 4];
            let cap = PyCapsule::new(py, stuff, &name).unwrap();

            cap.into()
        });

        // Release the capsule with the GIL held on the creating thread.
        Python::with_gil(move |py| {
            let ctx: &Vec<u8> = unsafe { cap.as_ref(py).reference() };
            assert_eq!(ctx, &[1, 2, 3, 4]);
        })
    }

    #[test]
    fn test_vec_context() {
        let context: Vec<u8> = vec![1, 2, 3, 4];

        let cap: Py<PyCapsule> = Python::with_gil(|py| {
            let name = CString::new("foo").unwrap();
            let cap = PyCapsule::new(py, 0, &name).unwrap();
            cap.set_context(py, Box::into_raw(Box::new(&context)) as _)
                .unwrap();

            cap.into()
        });

        // Only `cap` is captured by this `move` closure; `context` stays in place, so the
        // `&Vec<u8>` stored in the context pointer remains valid.
        Python::with_gil(move |py| {
            let ctx_ptr: *mut c_void = cap.as_ref(py).get_context(py).unwrap();
            let ctx = unsafe { *Box::from_raw(ctx_ptr as *mut &Vec<u8>) };
            assert_eq!(ctx, &vec![1_u8, 2, 3, 4]);
        })
    }

    #[test]
    fn test_pycapsule_destructor() {
        let (tx, rx) = channel::<bool>();

        fn destructor(_val: u32, ctx: *mut c_void) {
            assert!(!ctx.is_null());
            let context = unsafe { *Box::from_raw(ctx as *mut Sender<bool>) };
            context.send(true).unwrap();
        }

        Python::with_gil(|py| {
            let name = CString::new("foo").unwrap();
            let cap =
                PyCapsule::new_with_destructor(py, 0, &name, destructor as fn(u32, *mut c_void))
                    .unwrap();
            cap.set_context(py, Box::into_raw(Box::new(tx)) as _)
                .unwrap();
        });

        // The capsule is released when the GIL pool above is dropped, running the destructor.
        assert_eq!(rx.recv(), Ok(true));
    }
}