#![cfg(feature = "num-complex")]
#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"num-complex\"] }")]
use crate::{
ffi,
ffi_ptr_ext::FfiPtrExt,
types::{any::PyAnyMethods, PyComplex},
Bound, FromPyObject, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
};
use num_complex::Complex;
use std::os::raw::c_double;
impl PyComplex {
#[cfg_attr(
not(feature = "gil-refs"),
deprecated(
since = "0.21.0",
note = "`PyComplex::from_complex` will be replaced by `PyComplex::from_complex_bound` in a future PyO3 version"
)
)]
pub fn from_complex<F: Into<c_double>>(py: Python<'_>, complex: Complex<F>) -> &PyComplex {
Self::from_complex_bound(py, complex).into_gil_ref()
}
pub fn from_complex_bound<F: Into<c_double>>(
py: Python<'_>,
complex: Complex<F>,
) -> Bound<'_, PyComplex> {
unsafe {
ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
.assume_owned(py)
.downcast_into_unchecked()
}
}
}
// Implements `ToPyObject`, `IntoPy<PyObject>` and `FromPyObject` for `Complex<$float>`.
//
// Rust -> Python goes through `PyComplex_FromDoubles`, casting each component to `c_double`.
// Python -> Rust casts the C doubles back to `$float` with `as`, so for `f32` the value is
// narrowed from the double precision Python stores.
macro_rules! complex_conversion {
    ($float: ty) => {
        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl ToPyObject for Complex<$float> {
            #[inline]
            fn to_object(&self, py: Python<'_>) -> PyObject {
                // `Complex<$float>` is `Copy`, so a dereference is all that is needed to get an
                // owned value to hand to `into_py`.
                crate::IntoPy::<PyObject>::into_py(*self, py)
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl crate::IntoPy<PyObject> for Complex<$float> {
            fn into_py(self, py: Python<'_>) -> PyObject {
                // SAFETY: `PyComplex_FromDoubles` returns a new (owned) reference, which
                // `from_owned_ptr` takes ownership of.
                unsafe {
                    let raw_obj =
                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double);
                    PyObject::from_owned_ptr(py, raw_obj)
                }
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl FromPyObject<'_> for Complex<$float> {
            fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
                // Full C API: `PyComplex_AsCComplex` handles `complex`, `__complex__`,
                // `__float__` and `__index__` itself.
                #[cfg(not(any(Py_LIMITED_API, PyPy)))]
                // SAFETY: `obj` is a valid, live object pointer for the duration of the call.
                unsafe {
                    let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
                    // `-1.0` is the C API's error sentinel but also a legal real part, so it is
                    // only a failure when an exception is actually pending.
                    if val.real == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    Ok(Complex::new(val.real as $float, val.imag as $float))
                }

                // Limited API / PyPy: `PyComplex_AsCComplex` is unavailable, so resolve
                // `__complex__` manually (via special-method lookup on the type) and then read
                // the two components separately.
                #[cfg(any(Py_LIMITED_API, PyPy))]
                // SAFETY: `ptr` comes from a `Bound` that is kept alive until the end of this
                // block (`complex` outlives every use of `ptr`).
                unsafe {
                    let complex;
                    let obj = if obj.is_instance_of::<PyComplex>() {
                        obj
                    } else if let Some(method) =
                        obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
                    {
                        complex = method.call0()?;
                        &complex
                    } else {
                        obj
                    };
                    let ptr = obj.as_ptr();
                    let real = ffi::PyComplex_RealAsDouble(ptr);
                    if real == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    let imag = ffi::PyComplex_ImagAsDouble(ptr);
                    // Newer CPython versions document that `PyComplex_ImagAsDouble` can also fail,
                    // returning `-1.0` with an exception set. Never return `Ok` while a Python
                    // error is still pending.
                    if imag == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    Ok(Complex::new(real as $float, imag as $float))
                }
            }
        }
    };
}

// Both floating-point component types supported by the conversions above.
complex_conversion!(f32);
complex_conversion!(f64);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exceptions::PyValueError;
    use crate::types::{complex::PyComplexMethods, PyModule};

    #[test]
    fn from_complex() {
        Python::with_gil(|py| {
            let complex = Complex::new(3.0, 1.2);
            let py_c = PyComplex::from_complex_bound(py, complex);
            assert_eq!(py_c.real(), 3.0);
            assert_eq!(py_c.imag(), 1.2);
        });
    }

    #[test]
    fn to_from_complex() {
        Python::with_gil(|py| {
            let val = Complex::new(3.0, 1.2);
            let obj = val.to_object(py);
            assert_eq!(obj.extract::<Complex<f64>>(py).unwrap(), val);
        });
    }

    // Round-trips the `f32` instantiation of the conversion macro; 1.5 is exactly
    // representable, so narrowing back from the C double is lossless here.
    #[test]
    fn to_from_complex_f32() {
        Python::with_gil(|py| {
            let val = Complex::new(3.0f32, 1.5f32);
            let obj = val.to_object(py);
            assert_eq!(obj.extract::<Complex<f32>>(py).unwrap(), val);
        });
    }

    #[test]
    fn from_complex_err() {
        Python::with_gil(|py| {
            let obj = vec![1].to_object(py);
            assert!(obj.extract::<Complex<f64>>(py).is_err());
        });
    }

    #[test]
    fn from_python_magic() {
        Python::with_gil(|py| {
            let module = PyModule::from_code_bound(
                py,
                r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
"#,
                "test.py",
                "test",
            )
            .unwrap();
            let from_complex = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = module.getattr("B").unwrap().call0().unwrap();
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            #[cfg(Py_3_8)]
            {
                let from_index = module.getattr("C").unwrap().call0().unwrap();
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }

    // Exceptions raised by `__complex__` or `__float__` must surface as the extraction error.
    #[test]
    fn from_python_magic_error_propagates() {
        Python::with_gil(|py| {
            let module = PyModule::from_code_bound(
                py,
                r#"
class A:
    def __complex__(self): raise ValueError("bad complex")
class B:
    def __float__(self): raise ValueError("bad float")
"#,
                "test.py",
                "test",
            )
            .unwrap();
            for name in ["A", "B"] {
                let obj = module.getattr(name).unwrap().call0().unwrap();
                let err = obj.extract::<Complex<f64>>().unwrap_err();
                assert!(err.is_instance_of::<PyValueError>(py));
            }
        })
    }

    #[test]
    fn from_python_inherited_magic() {
        Python::with_gil(|py| {
            let module = PyModule::from_code_bound(
                py,
                r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
"#,
                "test.py",
                "test",
            )
            .unwrap();
            let from_complex = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = module.getattr("B").unwrap().call0().unwrap();
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            #[cfg(Py_3_8)]
            {
                let from_index = module.getattr("C").unwrap().call0().unwrap();
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }

    #[test]
    fn from_python_noncallable_descriptor_magic() {
        Python::with_gil(|py| {
            let module = PyModule::from_code_bound(
                py,
                r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
"#,
                "test.py",
                "test",
            )
            .unwrap();
            let obj = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }

    #[test]
    fn from_python_nondescriptor_magic() {
        Python::with_gil(|py| {
            let module = PyModule::from_code_bound(
                py,
                r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
"#,
                "test.py",
                "test",
            )
            .unwrap();
            let obj = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }
}