use crate::err::{self, PyDowncastError, PyErr, PyResult};
use crate::gil::{self, GILGuard, GILPool};
use crate::type_object::{PyTypeInfo, PyTypeObject};
use crate::types::{PyAny, PyDict, PyModule, PyType};
use crate::{ffi, AsPyPointer, FromPyPointer, IntoPyPointer, PyNativeType, PyObject, PyTryFrom};
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int};
/// A Python version number parsed from the interpreter's version string.
///
/// Comparisons against `(major, minor)` and `(major, minor, patch)` tuples ignore
/// [`suffix`](Self::suffix), so e.g. `3.5.2a1+` compares equal to `(3, 5, 2)`.
#[derive(Debug)]
pub struct PythonVersionInfo<'py> {
    /// Major version number, e.g. `3` in `3.9.1`.
    pub major: u8,
    /// Minor version number, e.g. `9` in `3.9.1`.
    pub minor: u8,
    /// Patch version number; `0` when the version string has no patch component.
    pub patch: u8,
    /// Non-digit text trailing the last numeric component, e.g. `"a1+"` in `3.5.2a1+`
    /// or `"+"` in `3.5+`.
    pub suffix: Option<&'py str>,
}

impl<'py> PythonVersionInfo<'py> {
    /// Parses a version number of the form `major.minor[.patch]`, where the last component
    /// may carry a non-numeric suffix (e.g. `3.5+`, `3.5.2a1+`).
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if a component is missing or not numeric, if there
    /// are more than three components, or if a suffixed minor version is followed by a patch
    /// version.
    fn from_str(version_number_str: &'py str) -> Self {
        /// Splits a component such as `"2a1+"` into its leading number (`2`) and the remaining
        /// suffix (`Some("a1+")`); a purely numeric component yields no suffix.
        ///
        /// Panics if the component does not start with a number that fits in a `u8`.
        fn split_and_parse_number(version_part: &str) -> (u8, Option<&str>) {
            let (digits, suffix) = match version_part.find(|c: char| !c.is_ascii_digit()) {
                None => (version_part, None),
                Some(suffix_start) => {
                    let (digits, suffix) = version_part.split_at(suffix_start);
                    (digits, Some(suffix))
                }
            };
            // `expect` instead of `unwrap` so a malformed component (e.g. the empty digit run
            // in "3.x") produces a message consistent with the other checks below.
            let number = digits
                .parse()
                .expect("Python version component not an integer");
            (number, suffix)
        }

        let mut parts = version_number_str.split('.');
        let major_str = parts.next().expect("Python major version missing");
        let minor_str = parts.next().expect("Python minor version missing");
        let patch_str = parts.next();
        assert!(
            parts.next().is_none(),
            "Python version string has too many parts"
        );

        let major = major_str
            .parse()
            .expect("Python major version not an integer");
        let (minor, suffix) = split_and_parse_number(minor_str);
        if suffix.is_some() {
            // A suffix on the minor component (e.g. "3.5+") ends the number; no patch may follow.
            assert!(
                patch_str.is_none(),
                "Python version has both a suffixed minor version and a patch version"
            );
            return PythonVersionInfo {
                major,
                minor,
                patch: 0,
                suffix,
            };
        }

        // A missing patch component defaults to `(0, None)`.
        let (patch, suffix) = patch_str.map(split_and_parse_number).unwrap_or_default();
        PythonVersionInfo {
            major,
            minor,
            patch,
            suffix,
        }
    }
}

/// Equal when major and minor match; patch and suffix are ignored.
impl PartialEq<(u8, u8)> for PythonVersionInfo<'_> {
    fn eq(&self, other: &(u8, u8)) -> bool {
        self.major == other.0 && self.minor == other.1
    }
}

/// Equal when major, minor and patch match; the suffix is ignored.
impl PartialEq<(u8, u8, u8)> for PythonVersionInfo<'_> {
    fn eq(&self, other: &(u8, u8, u8)) -> bool {
        self.major == other.0 && self.minor == other.1 && self.patch == other.2
    }
}

/// Lexicographic comparison of `(major, minor)`; patch and suffix are ignored.
impl PartialOrd<(u8, u8)> for PythonVersionInfo<'_> {
    fn partial_cmp(&self, other: &(u8, u8)) -> Option<std::cmp::Ordering> {
        (self.major, self.minor).partial_cmp(other)
    }
}

/// Lexicographic comparison of `(major, minor, patch)`; the suffix is ignored.
impl PartialOrd<(u8, u8, u8)> for PythonVersionInfo<'_> {
    fn partial_cmp(&self, other: &(u8, u8, u8)) -> Option<std::cmp::Ordering> {
        (self.major, self.minor, self.patch).partial_cmp(other)
    }
}
/// Zero-sized marker token (it holds only `PhantomData`) handed to code that runs with the
/// GIL held for `'py`; see `with_gil` and `assume_gil_acquired`.
///
/// It is `Copy`, so it is passed by value throughout this API.
// NOTE(review): the `&'py GILGuard` phantom type presumably ties this token's variance and auto
// traits (e.g. `Send`/`Sync`) to `GILGuard` — confirm before changing the field type.
#[derive(Copy, Clone)]
pub struct Python<'py>(PhantomData<&'py GILGuard>);
impl Python<'_> {
    /// Runs `f` with a `Python` token obtained from `gil::ensure_gil()` and returns its result.
    ///
    /// See also
    #[cfg_attr(
        not(PyPy),
        doc = "[`prepare_freethreaded_python`](crate::prepare_freethreaded_python)"
    )]
    #[cfg_attr(PyPy, doc = "`prepare_freethreaded_python`")]
    /// regarding interpreter initialization.
    #[inline]
    pub fn with_gil<F, R>(f: F) -> R
    where
        F: for<'py> FnOnce(Python<'py>) -> R,
    {
        // The value returned by `ensure_gil()` (presumably a GIL guard) is a temporary of this
        // expression; it must stay alive until `f` returns, so keep this as a single expression.
        f(unsafe { gil::ensure_gil().python() })
    }
    /// Like `with_gil`, but obtains the token through `gil::ensure_gil_unchecked()`.
    ///
    /// Related:
    #[cfg_attr(
        all(Py_3_8, not(PyPy)),
        doc = "[`_Py_InitializeMain`](crate::ffi::_Py_InitializeMain)"
    )]
    #[cfg_attr(any(not(Py_3_8), PyPy), doc = "`_Py_InitializeMain`")]
    ///
    /// # Safety
    ///
    /// The caller must uphold whatever `gil::ensure_gil_unchecked` skips checking.
    /// NOTE(review): presumably this concerns interpreter initialization state — confirm against
    /// the `gil` module before relying on it.
    #[inline]
    pub unsafe fn with_gil_unchecked<F, R>(f: F) -> R
    where
        F: for<'py> FnOnce(Python<'py>) -> R,
    {
        // Same single-expression shape as `with_gil` so the guard temporary outlives `f`.
        f(gil::ensure_gil_unchecked().python())
    }
}
impl<'py> Python<'py> {
    /// Acquires the GIL and returns the guard produced by `GILGuard::acquire()`.
    ///
    /// See also
    #[cfg_attr(
        not(PyPy),
        doc = "[`prepare_freethreaded_python`](crate::prepare_freethreaded_python)"
    )]
    #[cfg_attr(PyPy, doc = "`prepare_freethreaded_python`")]
    /// regarding interpreter initialization.
    #[inline]
    pub fn acquire_gil() -> GILGuard {
        GILGuard::acquire()
    }
    /// Releases the GIL while running `f`, then re-acquires it and returns `f`'s result.
    ///
    /// This thread's `gil::GIL_COUNT` is saved and set to `0` for the duration of `f`, and is
    /// restored afterwards. Re-acquisition happens in a guard's `Drop`, so it also runs when
    /// `f` panics.
    ///
    /// The `Send` bound on `F` presumably exists to keep GIL-bound values such as the `Python`
    /// token out of the closure while the GIL is released (TODO confirm `Python` is `!Send`).
    pub fn allow_threads<T, F>(self, f: F) -> T
    where
        F: Send + FnOnce() -> T,
        T: Send,
    {
        // Restores the saved state when dropped, including during unwinding out of `f`.
        struct RestoreGuard {
            /// Value of `gil::GIL_COUNT` saved before the GIL was released.
            count: usize,
            /// Thread state returned by `PyEval_SaveThread`, passed back to `PyEval_RestoreThread`.
            tstate: *mut ffi::PyThreadState,
        }
        impl Drop for RestoreGuard {
            fn drop(&mut self) {
                gil::GIL_COUNT.with(|c| c.set(self.count));
                unsafe {
                    ffi::PyEval_RestoreThread(self.tstate);
                }
            }
        }
        // Zero this thread's count while the GIL is released, remembering the old value.
        let count = gil::GIL_COUNT.with(|c| c.replace(0));
        // Release the GIL; the returned thread state is needed to re-acquire it.
        let tstate = unsafe { ffi::PyEval_SaveThread() };
        // Create the guard immediately after releasing so every exit path re-acquires.
        let _guard = RestoreGuard { count, tstate };
        f()
    }
    /// Evaluates a single Python expression and returns its value.
    ///
    /// `globals` defaults to the `__main__` module's dict and `locals` defaults to `globals`.
    ///
    /// # Errors
    ///
    /// See `run_code`.
    pub fn eval(
        self,
        code: &str,
        globals: Option<&PyDict>,
        locals: Option<&PyDict>,
    ) -> PyResult<&'py PyAny> {
        self.run_code(code, ffi::Py_eval_input, globals, locals)
    }
    /// Executes Python statements (compiled with `Py_file_input`), discarding the result.
    ///
    /// `globals` defaults to the `__main__` module's dict and `locals` defaults to `globals`.
    ///
    /// # Errors
    ///
    /// See `run_code`.
    ///
    /// # Panics
    ///
    /// In debug builds, if the executed code produces a value other than `None`.
    pub fn run(
        self,
        code: &str,
        globals: Option<&PyDict>,
        locals: Option<&PyDict>,
    ) -> PyResult<()> {
        let res = self.run_code(code, ffi::Py_file_input, globals, locals);
        res.map(|obj| {
            debug_assert!(obj.is_none());
        })
    }
    /// Compiles `code` with the given start symbol (`Py_eval_input`, `Py_file_input`, ...)
    /// and evaluates it.
    ///
    /// `globals` defaults to the `__main__` module's dict; `locals` defaults to `globals`.
    ///
    /// # Errors
    ///
    /// Returns an error if `code` contains a NUL byte, if `__main__` cannot be obtained, if
    /// compilation fails, or if evaluation returns NULL (i.e. raised an exception).
    fn run_code(
        self,
        code: &str,
        start: c_int,
        globals: Option<&PyDict>,
        locals: Option<&PyDict>,
    ) -> PyResult<&'py PyAny> {
        // Fails on interior NUL bytes; the `?` converts that error into a `PyErr`.
        let code = CString::new(code)?;
        unsafe {
            // NOTE: `mptr` is not DECREF'd, presumably because `PyImport_AddModule` returns a
            // borrowed reference (per CPython docs).
            let mptr = ffi::PyImport_AddModule("__main__\0".as_ptr() as *const _);
            if mptr.is_null() {
                return Err(PyErr::fetch(self));
            }
            let globals = globals
                .map(AsPyPointer::as_ptr)
                .unwrap_or_else(|| ffi::PyModule_GetDict(mptr));
            let locals = locals.map(AsPyPointer::as_ptr).unwrap_or(globals);
            let code_obj = ffi::Py_CompileString(code.as_ptr(), "<string>\0".as_ptr() as _, start);
            if code_obj.is_null() {
                return Err(PyErr::fetch(self));
            }
            let res_ptr = ffi::PyEval_EvalCode(code_obj, globals, locals);
            // The compiled code object is only needed for evaluation; release it before
            // converting the result (which may itself be NULL / an error).
            ffi::Py_DECREF(code_obj);
            self.from_owned_ptr_or_err(res_ptr)
        }
    }
    /// Returns the Python type object for `T`.
    pub fn get_type<T>(self) -> &'py PyType
    where
        T: PyTypeObject,
    {
        T::type_object(self)
    }
    /// Imports the Python module `name` (delegates to `PyModule::import`).
    pub fn import(self, name: &str) -> PyResult<&'py PyModule> {
        PyModule::import(self, name)
    }
    /// Returns a new reference to Python's `None` singleton.
    // Capitalized to mirror the Python name, hence the lint allowance.
    #[allow(non_snake_case)]
    #[inline]
    pub fn None(self) -> PyObject {
        unsafe { PyObject::from_borrowed_ptr(self, ffi::Py_None()) }
    }
    /// Returns a new reference to Python's `NotImplemented` singleton.
    // Capitalized to mirror the Python name, hence the lint allowance.
    #[allow(non_snake_case)]
    #[inline]
    pub fn NotImplemented(self) -> PyObject {
        unsafe { PyObject::from_borrowed_ptr(self, ffi::Py_NotImplemented()) }
    }
    /// Returns the interpreter's full version string as reported by `Py_GetVersion()`.
    ///
    /// # Panics
    ///
    /// Panics if the version string is not valid UTF-8.
    pub fn version(self) -> &'py str {
        unsafe {
            CStr::from_ptr(ffi::Py_GetVersion() as *const c_char)
                .to_str()
                .expect("Python version string not UTF-8")
        }
    }
    /// Parses the version number (the text before the first space in `version()`) into a
    /// `PythonVersionInfo`.
    ///
    /// # Panics
    ///
    /// Panics if `version()` panics or the version number is not in the expected format.
    pub fn version_info(self) -> PythonVersionInfo<'py> {
        let version_str = self.version();
        // Version strings look like "3.9.1 (default, ...)"; keep only the number.
        let version_number_str = version_str.split(' ').next().unwrap_or(version_str);
        PythonVersionInfo::from_str(version_number_str)
    }
    /// Takes ownership of `obj` and attempts to downcast it to `T`.
    ///
    /// # Errors
    ///
    /// Returns `PyDowncastError` if the object is not a `T`.
    pub fn checked_cast_as<T>(self, obj: PyObject) -> Result<&'py T, PyDowncastError<'py>>
    where
        T: PyTryFrom<'py>,
    {
        // Ownership of the reference moves from `obj` into a `'py`-bound `&PyAny`.
        let any: &PyAny = unsafe { self.from_owned_ptr(obj.into_ptr()) };
        <T as PyTryFrom>::try_from(any)
    }
    /// Takes ownership of `obj` and reinterprets it as `T` via `T::unchecked_downcast`.
    ///
    /// # Safety
    ///
    /// No type check happens here; the caller must ensure `obj` really is a `T`.
    pub unsafe fn cast_as<T>(self, obj: PyObject) -> &'py T
    where
        T: PyNativeType + PyTypeInfo,
    {
        let any: &PyAny = self.from_owned_ptr(obj.into_ptr());
        T::unchecked_downcast(any)
    }
    // The `from_*` methods below take `self`, which clippy's `wrong_self_convention` flags.
    /// Delegates to `FromPyPointer::from_owned_ptr`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_owned_ptr`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_owned_ptr<T>(self, ptr: *mut ffi::PyObject) -> &'py T
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_owned_ptr(self, ptr)
    }
    /// Delegates to `FromPyPointer::from_owned_ptr_or_err`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_owned_ptr_or_err`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_owned_ptr_or_err<T>(self, ptr: *mut ffi::PyObject) -> PyResult<&'py T>
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_owned_ptr_or_err(self, ptr)
    }
    /// Delegates to `FromPyPointer::from_owned_ptr_or_opt`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_owned_ptr_or_opt`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_owned_ptr_or_opt<T>(self, ptr: *mut ffi::PyObject) -> Option<&'py T>
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_owned_ptr_or_opt(self, ptr)
    }
    /// Delegates to `FromPyPointer::from_borrowed_ptr`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_borrowed_ptr`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_borrowed_ptr<T>(self, ptr: *mut ffi::PyObject) -> &'py T
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_borrowed_ptr(self, ptr)
    }
    /// Delegates to `FromPyPointer::from_borrowed_ptr_or_err`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_borrowed_ptr_or_err`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_borrowed_ptr_or_err<T>(self, ptr: *mut ffi::PyObject) -> PyResult<&'py T>
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_borrowed_ptr_or_err(self, ptr)
    }
    /// Delegates to `FromPyPointer::from_borrowed_ptr_or_opt`.
    ///
    /// # Safety
    ///
    /// Same requirements as `FromPyPointer::from_borrowed_ptr_or_opt`.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn from_borrowed_ptr_or_opt<T>(self, ptr: *mut ffi::PyObject) -> Option<&'py T>
    where
        T: FromPyPointer<'py>,
    {
        FromPyPointer::from_borrowed_ptr_or_opt(self, ptr)
    }
    /// Calls `PyErr_CheckSignals` and converts a `-1` return into an error via
    /// `err::error_on_minusone`.
    pub fn check_signals(self) -> PyResult<()> {
        let v = unsafe { ffi::PyErr_CheckSignals() };
        err::error_on_minusone(self, v)
    }
    /// Creates a `Python` token without acquiring or checking the GIL.
    ///
    /// # Safety
    ///
    /// Nothing here verifies GIL ownership; the caller must ensure the GIL is held for `'py`.
    #[inline]
    pub unsafe fn assume_gil_acquired() -> Python<'py> {
        Python(PhantomData)
    }
    /// Creates a new `GILPool` via `GILPool::new()`; `self` is only used as proof of the GIL.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably objects registered while the pool is alive are released when it
    /// drops, so references obtained in that time must not outlive it — confirm against
    /// `GILPool`'s documentation.
    #[inline]
    pub unsafe fn new_pool(self) -> GILPool {
        GILPool::new()
    }
}
// Unit tests for the `Python` token API and `PythonVersionInfo` parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{IntoPyDict, PyList};
    // `eval` with default globals/locals, with a custom globals dict, and with a custom
    // locals dict (where builtins such as `min` must still resolve).
    #[test]
    fn test_eval() {
        Python::with_gil(|py| {
            let v: i32 = py
                .eval("min(1, 2)", None, None)
                .map_err(|e| e.print(py))
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(v, 1);
            let d = [("foo", 13)].into_py_dict(py);
            let v: i32 = py
                .eval("foo + 29", Some(d), None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(v, 42);
            let v: i32 = py
                .eval("foo + 29", None, Some(d))
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(v, 42);
            let v: i32 = py
                .eval("min(foo, 2)", None, Some(d))
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(v, 2);
        });
    }
    // The spawned thread must take the GIL before reaching the barrier, so the barrier only
    // completes if `allow_threads` really released it on this thread.
    #[test]
    fn test_allow_threads_releases_and_acquires_gil() {
        Python::with_gil(|py| {
            let b = std::sync::Arc::new(std::sync::Barrier::new(2));
            let b2 = b.clone();
            std::thread::spawn(move || Python::with_gil(|_| b2.wait()));
            py.allow_threads(|| {
                b.wait();
            });
            // Presumably checks that this thread holds a valid thread state again after
            // `allow_threads` returned (save/restore would misbehave otherwise) — TODO confirm.
            unsafe {
                let tstate = ffi::PyEval_SaveThread();
                ffi::PyEval_RestoreThread(tstate);
            }
        });
    }
    // A panic inside `allow_threads` must still re-acquire the GIL, so Python objects can be
    // used afterwards.
    #[test]
    fn test_allow_threads_panics_safely() {
        Python::with_gil(|py| {
            let result = std::panic::catch_unwind(|| unsafe {
                let py = Python::assume_gil_acquired();
                py.allow_threads(|| {
                    panic!("There was a panic!");
                });
            });
            assert!(result.is_err());
            let list = PyList::new(py, &[1, 2, 3, 4]);
            assert_eq!(list.extract::<Vec<i32>>().unwrap(), vec![1, 2, 3, 4]);
        });
    }
    // The running interpreter is at least as new as the `Py_3_*` cfgs this crate was built with.
    #[test]
    fn test_python_version_info() {
        Python::with_gil(|py| {
            let version = py.version_info();
            #[cfg(Py_3_6)]
            assert!(version >= (3, 6));
            #[cfg(Py_3_6)]
            assert!(version >= (3, 6, 0));
            #[cfg(Py_3_7)]
            assert!(version >= (3, 7));
            #[cfg(Py_3_7)]
            assert!(version >= (3, 7, 0));
            #[cfg(Py_3_8)]
            assert!(version >= (3, 8));
            #[cfg(Py_3_8)]
            assert!(version >= (3, 8, 0));
            #[cfg(Py_3_9)]
            assert!(version >= (3, 9));
            #[cfg(Py_3_9)]
            assert!(version >= (3, 9, 0));
        });
    }
    // Suffixes ("a1", "+") are parsed off and ignored by the tuple comparisons.
    #[test]
    fn test_python_version_info_parse() {
        assert!(PythonVersionInfo::from_str("3.5.0a1") >= (3, 5, 0));
        assert!(PythonVersionInfo::from_str("3.5+") >= (3, 5, 0));
        assert!(PythonVersionInfo::from_str("3.5+") == (3, 5, 0));
        assert!(PythonVersionInfo::from_str("3.5+") != (3, 5, 1));
        assert!(PythonVersionInfo::from_str("3.5.2a1+") < (3, 5, 3));
        assert!(PythonVersionInfo::from_str("3.5.2a1+") == (3, 5, 2));
        assert!(PythonVersionInfo::from_str("3.5.2a1+") == (3, 5));
        assert!(PythonVersionInfo::from_str("3.5+") == (3, 5));
        assert!(PythonVersionInfo::from_str("3.5.2a1+") < (3, 6));
        assert!(PythonVersionInfo::from_str("3.5.2a1+") > (3, 4));
    }
    // The GIL is held exactly while the `GILGuard` from `acquire_gil` is alive.
    #[test]
    #[cfg(not(Py_LIMITED_API))]
    fn test_acquire_gil() {
        // Return values of `PyGILState_Check` (0 = not held, 1 = held by this thread).
        const GIL_NOT_HELD: c_int = 0;
        const GIL_HELD: c_int = 1;
        let state = unsafe { crate::ffi::PyGILState_Check() };
        assert_eq!(state, GIL_NOT_HELD);
        {
            let gil = Python::acquire_gil();
            let _py = gil.python();
            let state = unsafe { crate::ffi::PyGILState_Check() };
            assert_eq!(state, GIL_HELD);
            drop(gil);
        }
        let state = unsafe { crate::ffi::PyGILState_Check() };
        assert_eq!(state, GIL_NOT_HELD);
    }
}