pyo3-arraylike 0.5.0

Convenience extension for rust-numpy
Documentation
use std::ffi::CStr;

use crate::{ArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use ndarray::{Array0, array};
use numpy::{
    get_array_module,
    pyo3::{PyAny, Python, types::IntoPyDict},
};
use pyo3::{Bound, ffi::c_str, types::PyAnyMethods};

/// Evaluates the Python expression `code` with the numpy module bound to the
/// global name `np`, returning the resulting Python object.
///
/// # Panics
///
/// Panics if the numpy module cannot be obtained, if the globals dict cannot
/// be built, or if evaluating `code` raises a Python exception. In the last
/// case the panic message includes the expression and the Python error.
fn eval<'py>(py: Python<'py>, code: &CStr) -> Bound<'py, PyAny> {
    // `get_array_module` is the call that fails when numpy is unavailable, so
    // the "not found" diagnostic belongs here, not on the dict construction.
    let numpy = get_array_module(py).expect("module `numpy` not found");
    let globals = [("np", numpy)]
        .into_py_dict(py)
        .expect("failed to build globals dict for `eval`");
    py.eval(code, Some(&globals), None)
        .unwrap_or_else(|err| panic!("failed to evaluate {code:?}: {err}"))
}

/// A `float64` ndarray already has the requested element type and rank, so
/// extraction is expected to reference it (`ArrayLike::PyRef`) rather than copy.
#[test]
fn extract_reference() {
    Python::attach(|py| {
        let source = eval(py, c_str!("np.array([[1,2],[3,4]], dtype='float64')"));
        let extracted: PyArrayLike2<f64> = source.extract().unwrap();
        assert!(matches!(extracted.0, ArrayLike::PyRef(_)));

        let expected = array![[1.0_f64, 2.0], [3.0, 4.0]];
        assert_eq!(expected, extracted.into_owned_array());
    });
}

/// An ndarray created with `dtype='int'` but requested as `f64` is expected to
/// be converted into an owned array (`ArrayLike::Owned`).
#[test]
fn convert_array_on_extract() {
    Python::attach(|py| {
        let int_array = eval(py, c_str!("np.array([[1,2],[3,4]], dtype='int')"));
        let converted: PyArrayLike2<f64> = int_array.extract().unwrap();
        assert!(matches!(converted.0, ArrayLike::Owned(_, _)));

        let owned = converted.into_owned_array();
        assert_eq!(array![[1.0_f64, 2.0], [3.0, 4.0]], owned);
    });
}

/// A nested Python list extracted as a 2-D `i32` array is expected to produce
/// owned data with the same values.
#[test]
fn convert_list_on_extract() {
    Python::attach(|py| {
        let nested = eval(py, c_str!("[[1,2],[3,4]]"));
        let result: PyArrayLike2<i32> = nested.extract().unwrap();
        assert!(matches!(result.0, ArrayLike::Owned(_, _)));

        let expected = array![[1_i32, 2], [3, 4]];
        assert_eq!(expected, result.into_owned_array());
    });
}

/// A list whose first row is an `int32` ndarray and whose second row is a plain
/// list is expected to be assembled into an owned 2-D array.
#[test]
fn convert_array_in_list_on_extract() {
    Python::attach(|py| {
        let mixed_rows = eval(py, c_str!("[np.array([1, 2], dtype='int32'), [3, 4]]"));
        let extracted: PyArrayLike2<i32> = mixed_rows.extract().unwrap();
        assert!(matches!(extracted.0, ArrayLike::Owned(_, _)));

        let owned = extracted.into_owned_array();
        assert_eq!(array![[1, 2], [3, 4]], owned);
    });
}

/// A three-level nested list extracted with dynamic rank is expected to yield
/// owned data equal to the corresponding 2x2x2 array.
#[test]
fn convert_list_on_extract_dyn() {
    Python::attach(|py| {
        let cube = eval(py, c_str!("[[[1,2],[3,4]],[[5,6],[7,8]]]"));
        let extracted: PyArrayLikeDyn<i32> = cube.extract().unwrap();
        assert!(matches!(extracted.0, ArrayLike::Owned(_, _)));

        let expected = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn();
        assert_eq!(expected, extracted.into_owned_array());
    });
}

/// The same flat list is extracted twice: once as a 1-D `u32` array and once as
/// a dynamic-rank `f64` array. Both extractions are expected to produce owned data.
#[test]
fn convert_1d_list_on_extract() {
    Python::attach(|py| {
        let flat = eval(py, c_str!("[1,2,3,4]"));

        // Fixed one-dimensional target.
        let as_1d: PyArrayLike1<u32> = flat.extract().unwrap();
        assert!(matches!(as_1d.0, ArrayLike::Owned(_, _)));
        assert_eq!(array![1, 2, 3, 4], as_1d.into_owned_array());

        // Dynamic-rank target with an element-type conversion.
        let as_dyn: PyArrayLikeDyn<f64> = flat.extract().unwrap();
        assert!(matches!(as_dyn.0, ArrayLike::Owned(_, _)));
        assert_eq!(
            array![1.0_f64, 2.0, 3.0, 4.0].into_dyn(),
            as_dyn.into_owned_array()
        );
    });
}

/// Extracting non-integral `float64` values as `i32` is expected to return an
/// error rather than succeed.
#[test]
fn unsafe_cast_shall_fail() {
    Python::attach(|py| {
        let floats = eval(py, c_str!("np.array([1.1,2.2,3.3,4.4], dtype='float64')"));
        let outcome = floats.extract::<PyArrayLike1<i32>>();
        assert!(outcome.is_err());
    });
}

/// Zero-dimensional inputs. A 0-d `int64` ndarray and a bare Python int are both
/// extracted as `PyArrayLike0<i32>`, and the bare int is also extracted as a
/// dynamic-rank `usize` array. Every result is expected to be owned.
#[test]
fn extract_0d_array() {
    Python::attach(|py| {
        let zero_dim = eval(py, c_str!("np.array(1, dtype='int64')"));
        let scalar = eval(py, c_str!("42"));

        let from_array: PyArrayLike0<i32> = zero_dim.extract().unwrap();
        let from_scalar: PyArrayLike0<i32> = scalar.extract().unwrap();
        let from_scalar_dyn: PyArrayLikeDyn<usize> = scalar.extract().unwrap();

        assert!(matches!(from_array.0, ArrayLike::Owned(_, _)));
        assert!(matches!(from_scalar.0, ArrayLike::Owned(_, _)));
        assert!(matches!(from_scalar_dyn.0, ArrayLike::Owned(_, _)));

        let expected_one = Array0::from_elem((), 1);
        let expected_answer = Array0::from_elem((), 42);
        assert_eq!(from_array.into_owned_array(), expected_one);
        assert_eq!(from_scalar.into_owned_array(), expected_answer);
        assert_eq!(
            from_scalar_dyn.into_owned_array().into_dyn(),
            Array0::from_elem((), 42_usize).into_dyn()
        );
    });
}