pyo3/conversions/
num_complex.rs1#![cfg(feature = "num-complex")]
2
3#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"num-complex\"] }")]
18use crate::{
97 ffi, ffi_ptr_ext::FfiPtrExt, types::PyComplex, Bound, FromPyObject, PyAny, PyErr, PyResult,
98 Python,
99};
100use num_complex::Complex;
101use std::ffi::c_double;
102
103impl PyComplex {
104 pub fn from_complex_bound<F: Into<c_double>>(
106 py: Python<'_>,
107 complex: Complex<F>,
108 ) -> Bound<'_, PyComplex> {
109 unsafe {
110 ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
111 .assume_owned(py)
112 .cast_into_unchecked()
113 }
114 }
115}
116
117macro_rules! complex_conversion {
118 ($float: ty) => {
119 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
120 impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
121 type Target = PyComplex;
122 type Output = Bound<'py, Self::Target>;
123 type Error = std::convert::Infallible;
124
125 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
126 unsafe {
127 Ok(
128 ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
129 .assume_owned(py)
130 .cast_into_unchecked(),
131 )
132 }
133 }
134 }
135
136 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
137 impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
138 type Target = PyComplex;
139 type Output = Bound<'py, Self::Target>;
140 type Error = std::convert::Infallible;
141
142 #[inline]
143 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
144 (*self).into_pyobject(py)
145 }
146 }
147
148 #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
149 impl FromPyObject<'_> for Complex<$float> {
150 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
151 #[cfg(not(any(Py_LIMITED_API, PyPy)))]
152 unsafe {
153 let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
154 if val.real == -1.0 {
155 if let Some(err) = PyErr::take(obj.py()) {
156 return Err(err);
157 }
158 }
159 Ok(Complex::new(val.real as $float, val.imag as $float))
160 }
161
162 #[cfg(any(Py_LIMITED_API, PyPy))]
163 unsafe {
164 use $crate::types::any::PyAnyMethods;
165 let complex;
166 let obj = if obj.is_instance_of::<PyComplex>() {
167 obj
168 } else if let Some(method) =
169 obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
170 {
171 complex = method.call0()?;
172 &complex
173 } else {
174 obj
178 };
179 let ptr = obj.as_ptr();
180 let real = ffi::PyComplex_RealAsDouble(ptr);
181 if real == -1.0 {
182 if let Some(err) = PyErr::take(obj.py()) {
183 return Err(err);
184 }
185 }
186 let imag = ffi::PyComplex_ImagAsDouble(ptr);
187 Ok(Complex::new(real as $float, imag as $float))
188 }
189 }
190 }
191 };
192}
193complex_conversion!(f32);
194complex_conversion!(f64);
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::generate_unique_module_name;
    use crate::types::PyAnyMethods as _;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;
    use pyo3_ffi::c_str;

    /// Compiles `source` as `test.py` into a freshly named Python module.
    fn load_module<'py>(py: Python<'py>, source: &std::ffi::CStr) -> Bound<'py, PyModule> {
        PyModule::from_code(
            py,
            source,
            c_str!("test.py"),
            &generate_unique_module_name("test"),
        )
        .unwrap()
    }

    /// Instantiates `class_name` from `module` without arguments and extracts the
    /// instance as a `Complex<f64>`, panicking if any step fails.
    fn extract_instance(module: &Bound<'_, PyModule>, class_name: &str) -> Complex<f64> {
        let instance = module.getattr(class_name).unwrap().call0().unwrap();
        instance.extract::<Complex<f64>>().unwrap()
    }

    /// `from_complex_bound` preserves both components.
    #[test]
    fn from_complex() {
        Python::attach(|py| {
            let value = Complex::new(3.0, 1.2);
            let py_value = PyComplex::from_complex_bound(py, value);
            assert_eq!(py_value.real(), 3.0);
            assert_eq!(py_value.imag(), 1.2);
        });
    }

    /// A value converted to Python and extracted again is unchanged.
    #[test]
    fn to_from_complex() {
        Python::attach(|py| {
            let expected = Complex::new(3.0f64, 1.2);
            let py_value = expected.into_pyobject(py).unwrap();
            let round_tripped = py_value.extract::<Complex<f64>>().unwrap();
            assert_eq!(round_tripped, expected);
        });
    }

    /// Extracting from an object that is not number-like fails.
    #[test]
    fn from_complex_err() {
        Python::attach(|py| {
            let list = vec![1i32].into_pyobject(py).unwrap();
            let outcome = list.extract::<Complex<f64>>();
            assert!(outcome.is_err());
        });
    }

    /// Classes defining `__complex__`, `__float__` or `__index__` directly are accepted.
    #[test]
    fn from_python_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            // The `__index__` case is only exercised when `Py_3_8` is set.
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    /// The same special methods are honoured when they come from a mixin base class.
    #[test]
    fn from_python_inherited_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            // The `__index__` case is only exercised when `Py_3_8` is set.
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    /// `__complex__` exposed through a `property` that returns a callable still works.
    #[test]
    fn from_python_noncallable_descriptor_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }

    /// `__complex__` set to a plain callable instance (not a function) still works.
    #[test]
    fn from_python_nondescriptor_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                c_str!(
                    r#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#
                ),
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }
}