// pyo3/conversions/num_complex.rs

1#![cfg(feature = "num-complex")]
2
//! Conversions to and from [num-complex](https://docs.rs/num-complex)'s
4//! [`Complex`]`<`[`f32`]`>` and [`Complex`]`<`[`f64`]`>`.
5//!
//! num-complex's [`Complex`] supports more operations than PyO3's [`PyComplex`]
7//! and can be used with the rest of the Rust ecosystem.
8//!
9//! # Setup
10//!
11//! To use this feature, add this to your **`Cargo.toml`**:
12//!
13//! ```toml
14//! [dependencies]
15//! # change * to the latest versions
16//! num-complex = "*"
17#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-complex\"] }")]
18//! ```
19//!
20//! Note that you must use compatible versions of num-complex and PyO3.
21//! The required num-complex version may vary based on the version of PyO3.
22//!
23//! # Examples
24//!
25//! Using [num-complex](https://docs.rs/num-complex) and [nalgebra](https://docs.rs/nalgebra)
26//! to create a pyfunction that calculates the eigenvalues of a 2x2 matrix.
27//! ```ignore
28//! # // not tested because nalgebra isn't supported on msrv
29//! # // please file an issue if it breaks!
30//! use nalgebra::base::{dimension::Const, Matrix};
31//! use num_complex::Complex;
32//! use pyo3::prelude::*;
33//!
34//! type T = Complex<f64>;
35//!
36//! #[pyfunction]
37//! fn get_eigenvalues(m11: T, m12: T, m21: T, m22: T) -> Vec<T> {
38//!     let mat = Matrix::<T, Const<2>, Const<2>, _>::new(m11, m12, m21, m22);
39//!
40//!     match mat.eigenvalues() {
41//!         Some(e) => e.data.as_slice().to_vec(),
42//!         None => vec![],
43//!     }
44//! }
45//!
46//! #[pymodule]
47//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
48//!     m.add_function(wrap_pyfunction!(get_eigenvalues, m)?)?;
49//!     Ok(())
50//! }
51//! # // test
52//! # use assert_approx_eq::assert_approx_eq;
53//! # use nalgebra::ComplexField;
54//! # use pyo3::types::PyComplex;
55//! #
56//! # fn main() -> PyResult<()> {
57//! #     Python::attach(|py| -> PyResult<()> {
58//! #         let module = PyModule::new(py, "my_module")?;
59//! #
60//! #         module.add_function(&wrap_pyfunction!(get_eigenvalues, module)?)?;
61//! #
62//! #         let m11 = PyComplex::from_doubles(py, 0_f64, -1_f64);
63//! #         let m12 = PyComplex::from_doubles(py, 1_f64, 0_f64);
64//! #         let m21 = PyComplex::from_doubles(py, 2_f64, -1_f64);
65//! #         let m22 = PyComplex::from_doubles(py, -1_f64, 0_f64);
66//! #
67//! #         let result = module
68//! #             .getattr("get_eigenvalues")?
69//! #             .call1((m11, m12, m21, m22))?;
70//! #         println!("eigenvalues: {:?}", result);
71//! #
72//! #         let result = result.extract::<Vec<T>>()?;
73//! #         let e0 = result[0];
74//! #         let e1 = result[1];
75//! #
76//! #         assert_approx_eq!(e0, Complex::new(1_f64, -1_f64));
77//! #         assert_approx_eq!(e1, Complex::new(-2_f64, 0_f64));
78//! #
79//! #         Ok(())
80//! #     })
81//! # }
82//! ```
83//!
84//! Python code:
85//! ```python
86//! from my_module import get_eigenvalues
87//!
88//! m11 = complex(0,-1)
89//! m12 = complex(1,0)
90//! m21 = complex(2,-1)
91//! m22 = complex(-1,0)
92//!
93//! result = get_eigenvalues(m11,m12,m21,m22)
94//! assert result == [complex(1,-1), complex(-2,0)]
95//! ```
96use crate::{
97    ffi, ffi_ptr_ext::FfiPtrExt, types::PyComplex, Bound, FromPyObject, PyAny, PyErr, PyResult,
98    Python,
99};
100use num_complex::Complex;
101use std::ffi::c_double;
102
103impl PyComplex {
104    /// Creates a new Python `PyComplex` object from `num_complex`'s [`Complex`].
105    pub fn from_complex_bound<F: Into<c_double>>(
106        py: Python<'_>,
107        complex: Complex<F>,
108    ) -> Bound<'_, PyComplex> {
109        unsafe {
110            ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
111                .assume_owned(py)
112                .cast_into_unchecked()
113        }
114    }
115}
116
// Generates the Python conversions for `Complex<$float>`:
//
// - `IntoPyObject` for `Complex<$float>` and `&Complex<$float>`, producing a Python `complex`;
// - `FromPyObject` for `Complex<$float>`, accepting `complex` instances as well as objects
//   implementing `__complex__`, `__float__` or (on Python 3.8+) `__index__`.
macro_rules! complex_conversion {
    ($float: ty) => {
        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
            type Target = PyComplex;
            type Output = Bound<'py, Self::Target>;
            type Error = std::convert::Infallible;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                // SAFETY: `PyComplex_FromDoubles` returns a new reference to a `complex`
                // object, so we own it and it is known to be a `PyComplex`.
                unsafe {
                    Ok(
                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
                            .assume_owned(py)
                            .cast_into_unchecked(),
                    )
                }
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
            type Target = PyComplex;
            type Output = Bound<'py, Self::Target>;
            type Error = std::convert::Infallible;

            #[inline]
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                // `Complex<$float>` is `Copy`, so delegate to the by-value conversion.
                (*self).into_pyobject(py)
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
        impl FromPyObject<'_> for Complex<$float> {
            fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
                // Full C API: `PyComplex_AsCComplex` performs the whole conversion in one
                // call. It signals failure by returning -1.0 as the real part with an
                // exception set; -1.0 is also a legitimate value, so the error indicator
                // must be consulted to tell the two apart.
                // SAFETY: `obj` is a live object for the duration of the call.
                #[cfg(not(any(Py_LIMITED_API, PyPy)))]
                unsafe {
                    let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
                    if val.real == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    Ok(Complex::new(val.real as $float, val.imag as $float))
                }

                // Limited API / PyPy: `PyComplex_AsCComplex` is unavailable, so resolve
                // `__complex__` by hand and then read the two parts separately.
                // SAFETY: `obj` (or the `__complex__` result it is rebound to) is a live
                // object for the duration of both calls.
                #[cfg(any(Py_LIMITED_API, PyPy))]
                unsafe {
                    use $crate::types::any::PyAnyMethods;
                    let complex;
                    let obj = if obj.is_instance_of::<PyComplex>() {
                        obj
                    } else if let Some(method) =
                        obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
                    {
                        complex = method.call0()?;
                        &complex
                    } else {
                        // `obj` might still implement `__float__` or `__index__`, which will be
                        // handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
                        // errors if those methods don't exist / raise exceptions.
                        obj
                    };
                    let ptr = obj.as_ptr();
                    let real = ffi::PyComplex_RealAsDouble(ptr);
                    if real == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    // `PyComplex_ImagAsDouble` can fail too (since Python 3.13 it may call
                    // `__complex__` and returns -1.0 with an exception set on error). Check
                    // it like the real part so no exception is left pending.
                    let imag = ffi::PyComplex_ImagAsDouble(ptr);
                    if imag == -1.0 {
                        if let Some(err) = PyErr::take(obj.py()) {
                            return Err(err);
                        }
                    }
                    Ok(Complex::new(real as $float, imag as $float))
                }
            }
        }
    };
}
193complex_conversion!(f32);
194complex_conversion!(f64);
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::generate_unique_module_name;
    use crate::types::PyAnyMethods as _;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;
    use pyo3_ffi::c_str;

    #[test]
    fn from_complex() {
        Python::attach(|py| {
            let complex = Complex::new(3.0, 1.2);
            let py_c = PyComplex::from_complex_bound(py, complex);
            assert_eq!(py_c.real(), 3.0);
            assert_eq!(py_c.imag(), 1.2);
        });
    }
    #[test]
    fn to_from_complex() {
        Python::attach(|py| {
            let val = Complex::new(3.0f64, 1.2);
            let obj = val.into_pyobject(py).unwrap();
            assert_eq!(obj.extract::<Complex<f64>>().unwrap(), val);
        });
    }
    #[test]
    fn to_from_complex_f32() {
        // The `f32` instantiation of the conversions must round-trip too, both by value
        // and through the by-reference `IntoPyObject` impl.
        Python::attach(|py| {
            let val = Complex::new(3.0f32, 1.5);
            let obj = val.into_pyobject(py).unwrap();
            assert_eq!(obj.extract::<Complex<f32>>().unwrap(), val);
            let obj = (&val).into_pyobject(py).unwrap();
            assert_eq!(obj.extract::<Complex<f32>>().unwrap(), val);
        });
    }
    #[test]
    fn from_complex_err() {
        Python::attach(|py| {
            let obj = vec![1i32].into_pyobject(py).unwrap();
            assert!(obj.extract::<Complex<f64>>().is_err());
        });
    }
    #[test]
    fn from_python_magic_err() {
        // Exceptions raised by `__complex__` / `__float__` must surface as the extraction
        // error rather than being swallowed.
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): raise ValueError("complex")
class B:
    def __float__(self): raise ValueError("float")
                "#
                ),
                c_str!("test.py"),
                &generate_unique_module_name("test"),
            )
            .unwrap();
            for name in ["A", "B"] {
                let obj = module.getattr(name).unwrap().call0().unwrap();
                let err = obj.extract::<Complex<f64>>().unwrap_err();
                assert!(err.is_instance_of::<crate::exceptions::PyValueError>(py));
            }
        })
    }
    #[test]
    fn from_python_magic() {
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#
                ),
                c_str!("test.py"),
                &generate_unique_module_name("test"),
            )
            .unwrap();
            let from_complex = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = module.getattr("B").unwrap().call0().unwrap();
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
            #[cfg(Py_3_8)]
            {
                let from_index = module.getattr("C").unwrap().call0().unwrap();
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }
    #[test]
    fn from_python_inherited_magic() {
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(
                    r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#
                ),
                c_str!("test.py"),
                &generate_unique_module_name("test"),
            )
            .unwrap();
            let from_complex = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                from_complex.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
            let from_float = module.getattr("B").unwrap().call0().unwrap();
            assert_eq!(
                from_float.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 0.0)
            );
            #[cfg(Py_3_8)]
            {
                let from_index = module.getattr("C").unwrap().call0().unwrap();
                assert_eq!(
                    from_index.extract::<Complex<f64>>().unwrap(),
                    Complex::new(3.0, 0.0)
                );
            }
        })
    }
    #[test]
    fn from_python_noncallable_descriptor_magic() {
        // Functions and lambdas implement the descriptor protocol in a way that makes
        // `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
        // way the descriptor protocol might be implemented.
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(
                    r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#
                ),
                c_str!("test.py"),
                &generate_unique_module_name("test"),
            )
            .unwrap();
            let obj = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }
    #[test]
    fn from_python_nondescriptor_magic() {
        // Magic methods don't need to implement the descriptor protocol, if they're callable.
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(
                    r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#
                ),
                c_str!("test.py"),
                &generate_unique_module_name("test"),
            )
            .unwrap();
            let obj = module.getattr("A").unwrap().call0().unwrap();
            assert_eq!(
                obj.extract::<Complex<f64>>().unwrap(),
                Complex::new(3.0, 1.2)
            );
        })
    }
}