use approx::assert_abs_diff_eq;
use pyo3::prelude::*;
use pyo3::types::{PyAnyMethods, PyList, PyModule};
use rustpython_vm::{AsObject, Interpreter, builtins::PyList as RpyList};
/// Builds a fresh RustPython interpreter with the `rumpy` native module registered.
fn rumpy_interp() -> Interpreter {
    let builder = Interpreter::builder(Default::default());
    let rumpy_module = rumpy::module_def(&builder.ctx);
    builder.add_native_module(rumpy_module).build()
}
/// A snippet's `result`, flattened into a form both backends can be compared in.
#[derive(Debug)]
struct Out {
    /// Length of every axis; empty for scalar (0-d) results.
    shape: Vec<usize>,
    /// All elements, flattened and converted to `f64`.
    data: Vec<f64>,
}
/// Executes `source` in a fresh RustPython interpreter (with `rumpy` available
/// as a native module) and converts the snippet's global `result` via `extract`.
///
/// # Panics
/// Panics, echoing the snippet source, if compilation fails, execution raises,
/// the snippet never assigns `result`, or `result` cannot be converted.
fn rumpy_run(source: &str) -> Out {
    let interp = rumpy_interp();
    interp
        .enter(|vm| -> Result<Out, String> {
            let scope = vm.new_scope_with_builtins();
            let code = vm
                .compile(source, rustpython_vm::compiler::Mode::Exec, "<t>".into())
                .map_err(|e| format!("compile: {e}"))?;
            vm.run_code_obj(code, scope.clone())
                .map_err(|e| pyerr(vm, &e))?;
            // A missing `result` is reported through the error path (not a
            // panic inside the VM) so the final message includes the snippet.
            let r = scope
                .globals
                .get_item("result", vm)
                .map_err(|e| format!("snippet did not set `result`: {}", pyerr(vm, &e)))?;
            extract(&r, vm).map_err(|e| pyerr(vm, &e))
        })
        .unwrap_or_else(|e| panic!("rumpy: {e}\n--- src ---\n{source}"))
}
/// Renders a RustPython exception into a `String` (formatted by
/// `VirtualMachine::write_exception`) for use in panic messages.
fn pyerr(
    vm: &rustpython_vm::VirtualMachine,
    e: &rustpython_vm::PyRef<rustpython_vm::builtins::PyBaseException>,
) -> String {
    let mut rendered = String::new();
    // A write failure is deliberately ignored; whatever text was produced is returned.
    let _ = vm.write_exception(&mut rendered, e);
    rendered
}
/// Converts a snippet's `result` object into an [`Out`].
///
/// Tried in order: a `rumpy` ndarray (cast to float64), the `True` / `False`
/// singletons, anything `try_float` accepts (reported as a 0-d scalar), and a
/// nested list (flattened by `flatten`). Anything else yields a `TypeError`.
fn extract(
    obj: &rustpython_vm::PyObjectRef,
    vm: &rustpython_vm::VirtualMachine,
) -> rustpython_vm::PyResult<Out> {
    use rumpy::{ArraysD, DType, PyNdArray};

    // Every scalar branch reports an empty shape holding a single element.
    let scalar = |value: f64| Out {
        shape: vec![],
        data: vec![value],
    };

    if let Some(array) = obj.downcast_ref::<PyNdArray>() {
        let as_f64 = array.view().cast(DType::F64);
        let ArraysD::F64(values) = as_f64 else {
            unreachable!()
        };
        let shape = values.shape().to_vec();
        let data = values.iter().copied().collect();
        return Ok(Out { shape, data });
    }
    if obj.is(&vm.ctx.true_value) {
        return Ok(scalar(1.0));
    }
    if obj.is(&vm.ctx.false_value) {
        return Ok(scalar(0.0));
    }
    if let Ok(number) = obj.try_float(vm) {
        return Ok(scalar(number.to_f64()));
    }
    match obj.downcast_ref::<RpyList>() {
        Some(list) => {
            let (mut shape, mut data) = (Vec::new(), Vec::new());
            flatten(list, &mut shape, &mut data, vm, 0)?;
            Ok(Out { shape, data })
        }
        None => Err(vm.new_type_error(format!("bad result type {}", obj.class().name()))),
    }
}
/// Flattens a (possibly nested) RustPython list of numbers into `data`,
/// recording the length of each nesting level in `shape`.
///
/// `depth` is the nesting level of `l` (0 for the outermost list). The first
/// list met at each depth fixes `shape[depth]`.
///
/// # Errors
/// Returns a `ValueError` for ragged input: a later list at the same depth has
/// a different length, or scalars and lists are mixed across levels. Ragged
/// input is rejected because it would otherwise yield a `shape` that
/// disagrees with `data`. Propagates any error from converting an element to
/// float.
fn flatten(
    l: &RpyList,
    shape: &mut Vec<usize>,
    data: &mut Vec<f64>,
    vm: &rustpython_vm::VirtualMachine,
    depth: usize,
) -> rustpython_vm::PyResult<()> {
    let mut leaf_depth = None;
    flatten_checked(l, shape, data, vm, depth, &mut leaf_depth)
}

/// Recursive worker for `flatten`; `leaf_depth` remembers the depth of the
/// first scalar seen so later scalars and lists can be checked against it.
fn flatten_checked(
    l: &RpyList,
    shape: &mut Vec<usize>,
    data: &mut Vec<f64>,
    vm: &rustpython_vm::VirtualMachine,
    depth: usize,
    leaf_depth: &mut Option<usize>,
) -> rustpython_vm::PyResult<()> {
    // Scalars were already found at a shallower level, so no list may live here.
    if leaf_depth.is_some_and(|leaf| depth > leaf) {
        return Err(vm.new_value_error(format!(
            "ragged result list: found a list at depth {depth} where scalars were expected"
        )));
    }
    let items = l.borrow_vec();
    match shape.get(depth) {
        None => shape.push(items.len()),
        Some(&expected) if expected != items.len() => {
            return Err(vm.new_value_error(format!(
                "ragged result list: length {} at depth {depth}, expected {expected}",
                items.len()
            )));
        }
        Some(_) => {}
    }
    for it in items.iter() {
        if let Some(sub) = it.downcast_ref::<RpyList>() {
            flatten_checked(sub, shape, data, vm, depth + 1, leaf_depth)?;
        } else {
            // A scalar marks this level as the innermost one. It must agree
            // with earlier scalars, and no deeper list may already exist.
            let leaf = *leaf_depth.get_or_insert(depth);
            if leaf != depth || shape.len() > depth + 1 {
                return Err(vm.new_value_error(format!(
                    "ragged result list: found a scalar at depth {depth} where lists were expected"
                )));
            }
            data.push(it.try_float(vm)?.to_f64());
        }
    }
    Ok(())
}
/// Executes `source` under CPython with real NumPy (via pyo3) and returns the
/// snippet's global `result` as an [`Out`].
///
/// `result` is passed through `np.asarray`, then its shape is read and the
/// elements are flattened with `ravel().tolist()`. Python bools become 1.0 / 0.0.
///
/// # Panics
/// Panics, echoing the snippet source, if NumPy cannot be imported, the
/// snippet raises, never assigns `result`, or produces non-real elements. Also
/// panics if `source` contains an interior NUL byte.
fn numpy_run(source: &str) -> Out {
    Python::attach(|py| -> PyResult<Out> {
        let globals = pyo3::types::PyDict::new(py);
        let np = PyModule::import(py, "numpy")?;
        globals.set_item("np", &np)?;
        let code = std::ffi::CString::new(source)
            .expect("snippet source must not contain interior NUL bytes");
        py.run(&code, Some(&globals), None)?;
        // A missing `result` becomes a Python error rather than a bare unwrap panic.
        let result = globals.get_item("result")?.ok_or_else(|| {
            pyo3::exceptions::PyKeyError::new_err("snippet did not set `result`")
        })?;
        let arr = np.getattr("asarray")?.call1((result,))?;
        let shape: Vec<usize> = arr.getattr("shape")?.extract()?;
        let flat = arr.call_method0("ravel")?.call_method0("tolist")?;
        let data: Vec<f64> = flat
            .cast::<PyList>()?
            .iter()
            .map(|x| {
                // Check bool first so True/False map to 1.0/0.0 explicitly.
                if let Ok(b) = x.extract::<bool>() {
                    Ok(if b { 1.0 } else { 0.0 })
                } else {
                    x.extract::<f64>()
                }
            })
            .collect::<PyResult<_>>()?;
        Ok(Out { shape, data })
    })
    .unwrap_or_else(|e| panic!("numpy: {e}\n--- src ---\n{source}"))
}
/// Runs snippet `s` under both rumpy and real NumPy and asserts the results agree.
///
/// Shapes and element counts must match exactly. Elements must be equal
/// within an absolute tolerance of 1e-7. NaN matches NaN, and identical values
/// (including same-signed infinities) always match.
///
/// # Panics
/// Panics on any mismatch; the snippet is printed alongside the failure.
fn assert_same(s: &str) {
    const EPS: f64 = 1e-7;
    let r = rumpy_run(s);
    let n = numpy_run(s);
    assert_eq!(r.shape, n.shape, "shape mismatch for snippet:\n{s}");
    assert_eq!(r.data.len(), n.data.len(), "len mismatch:\n{s}");
    for (i, (&a, &b)) in r.data.iter().zip(&n.data).enumerate() {
        // Exact equality must come first: for equal infinities `a - b` is NaN,
        // so an absolute-difference check alone would report a false mismatch.
        if a == b || (a.is_nan() && b.is_nan()) {
            continue;
        }
        let close = (a - b).abs() <= EPS;
        if !close {
            eprintln!("value mismatch at flat index {i} (rumpy={a}, numpy={b}) for snippet:\n{s}");
        }
        assert_abs_diff_eq!(a, b, epsilon = EPS);
    }
}
/// sum/mean/min/max over a one-element array must all match NumPy.
#[test]
fn size_1_array_sum_mean() {
    let snippet = r#"
import numpy as np
a = np.array([42.0])
result = np.array([float(a.sum()), float(a.mean()), float(a.min()), float(a.max())])
"#;
    assert_same(snippet);
}
/// `int()` on a one-element float array truncates its value (7.5 -> 7).
#[test]
fn size_1_int_conversion() {
    let src = r#"
import numpy as np
a = np.array([7.5])
result = np.array([float(int(a))])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [7.0]);
}
/// `bool()` on a one-element non-zero array is truthy.
#[test]
fn size_1_bool_conversion() {
    let src = r#"
import numpy as np
a = np.array([5.0])
result = np.array([1.0 if bool(a) else 0.0])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0]);
}
/// An empty float64 array reports `shape[0] == 0` and `size == 0`.
#[test]
fn zero_length_array_shape() {
    let src = r#"
import numpy as np
a = np.array([], dtype="float64")
result = np.array([float(a.shape[0]), float(a.size)])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 0.0]);
}
/// Scalar addition on an empty array stays empty, matching NumPy.
#[test]
fn zero_length_array_arithmetic() {
    let snippet = r#"
import numpy as np
a = np.array([], dtype="float64")
b = a + 1.0
result = b
"#;
    assert_same(snippet);
}
/// Adding 1.0 leaves +inf / -inf unchanged and shifts finite values.
#[test]
fn inf_arithmetic() {
    let src = r#"
import numpy as np
inf = float("inf")
a = np.array([1.0, inf, -inf, 0.0])
result = a + 1.0
"#;
    let out = rumpy_run(src);
    let d = &out.data;
    assert_eq!(d[0], 2.0);
    // `is_infinite() && > 0` / `&& < 0` is exactly equality with +inf / -inf.
    assert_eq!(d[1], f64::INFINITY);
    assert_eq!(d[2], f64::NEG_INFINITY);
    assert_eq!(d[3], 1.0);
}
/// Element-wise `==` treats NaN as unequal to itself.
#[test]
fn nan_comparison_returns_false() {
    let src = r#"
import numpy as np
nan = float("nan")
a = np.array([nan, 1.0])
b = np.array([nan, 1.0])
result = (a == b).astype(int)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 1.0]);
}
/// `0.0` and `-0.0` compare equal element-wise.
#[test]
fn pos_neg_zero_equal() {
    let src = r#"
import numpy as np
a = np.array([0.0, -0.0])
b = np.array([-0.0, 0.0])
result = (a == b).astype(int)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0, 1.0]);
}
/// Despite its name, this exercises multiplication: squaring 1e-300 underflows
/// to zero and squaring 1e300 overflows to infinity.
#[test]
fn tiny_addition_preserves_precision() {
    let src = r#"
import numpy as np
a = np.array([1e-300, 1e300])
b = a * a
result = b
"#;
    let out = rumpy_run(src);
    let (underflow, overflow) = (out.data[0], out.data[1]);
    assert_eq!(underflow, 0.0);
    assert!(overflow.is_infinite());
}
/// Reshaping a 4-D array to (6, -1) infers the trailing dimension like NumPy.
#[test]
fn reshape_4d() {
    let snippet = r#"
import numpy as np
a = np.arange(120.0).reshape(2, 3, 4, 5)
result = a.reshape(6, -1)
"#;
    assert_same(snippet);
}
/// argmin / argmax of a one-element array are both index 0.
#[test]
fn argmin_argmax_single_element() {
    let src = r#"
import numpy as np
a = np.array([42.0])
result = np.array([np.argmin(a), np.argmax(a)]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 0.0]);
}
/// With a tied maximum, argmax returns the first index holding it.
#[test]
fn argmax_returns_first_occurrence() {
    let src = r#"
import numpy as np
a = np.array([1.0, 5.0, 5.0, 5.0, 2.0])
result = np.array([np.argmax(a)]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0]);
}
/// `np.abs` on complex input yields the magnitude (3+4j -> 5).
#[test]
fn complex_magnitude_via_abs() {
    let src = r#"
import numpy as np
a = np.array([3+4j, 6+8j, 0+0j])
result = np.abs(a)
"#;
    let out = rumpy_run(src);
    let expected = [5.0, 10.0, 0.0];
    for (i, want) in expected.into_iter().enumerate() {
        assert!((out.data[i] - want).abs() < 1e-9);
    }
}
/// Adding a real array to a complex array adds into the real parts.
#[test]
fn complex_plus_real_array() {
    let src = r#"
import numpy as np
a = np.array([1+0j, 2+0j])
b = np.array([10.0, 20.0])
c = a + b
result = np.array([c[0].real, c[1].real])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [11.0, 22.0]);
}
/// `ndarray.shape` is a Python tuple.
#[test]
fn shape_is_tuple() {
    let src = r#"
import numpy as np
a = np.zeros((2, 3, 4))
result = np.array([1.0 if isinstance(a.shape, tuple) else 0.0])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0]);
}
/// `ndim` reports 0, 1 and 2 for scalar, vector and matrix arrays.
#[test]
fn ndim_for_0d_1d_2d() {
    let src = r#"
import numpy as np
a0 = np.array(42.0)
a1 = np.array([1.0, 2.0])
a2 = np.array([[1.0]])
result = np.array([a0.ndim, a1.ndim, a2.ndim]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 1.0, 2.0]);
}
/// `np.full` with a negative fill value matches NumPy.
#[test]
fn full_negative_value() {
    let snippet = r#"
import numpy as np
result = np.full((3,), -7.5)
"#;
    assert_same(snippet);
}
/// `np.empty` honours the requested shape (contents are not inspected).
#[test]
fn empty_correct_shape() {
    let src = r#"
import numpy as np
a = np.empty((4, 5))
result = np.array([a.shape[0], a.shape[1]]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [4.0, 5.0]);
}
/// Seeded `np.random.rand` samples lie in the half-open interval [0, 1).
#[test]
fn random_uniform_in_range() {
    let src = r#"
import numpy as np
np.random.seed(0)
a = np.random.rand(100)
result = np.array([float(a.min()), float(a.max())])
"#;
    let out = rumpy_run(src);
    let lowest = out.data[0];
    assert!(lowest >= 0.0);
    let highest = out.data[1];
    assert!(highest < 1.0);
}
/// Seeded `np.random.randint(0, 10, ...)` samples lie in [0, 10).
#[test]
fn random_randint_in_range() {
    let src = r#"
import numpy as np
np.random.seed(7)
a = np.random.randint(0, 10, 100)
result = np.array([float(a.min()), float(a.max())])
"#;
    let out = rumpy_run(src);
    let lowest = out.data[0];
    assert!(lowest >= 0.0 && out.data[1] < 10.0);
}
/// `len(tobytes())` equals `nbytes`, i.e. 12 float64 values = 96 bytes.
#[test]
fn tobytes_size_matches_nbytes() {
    let src = r#"
import numpy as np
a = np.arange(12.0).reshape(3, 4)
b = a.tobytes()
result = np.array([len(b), a.nbytes]).astype(float)
"#;
    let out = rumpy_run(src);
    let (byte_len, nbytes) = (out.data[0], out.data[1]);
    assert_eq!(byte_len, nbytes);
    assert_eq!(byte_len, 96.0);
}
/// `itemsize` for int8 / int32 / float64 / complex128 is 1 / 4 / 8 / 16 bytes.
#[test]
fn itemsize_grid() {
    let src = r#"
import numpy as np
sizes = [
np.zeros(1, dtype="int8").itemsize,
np.zeros(1, dtype="int32").itemsize,
np.zeros(1, dtype="float64").itemsize,
np.zeros(1, dtype="complex128").itemsize,
]
result = np.array(sizes).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0, 4.0, 8.0, 16.0]);
}
/// A chain of mixed array/scalar arithmetic matches NumPy.
#[test]
fn arith_chain() {
    let snippet = r#"
import numpy as np
a = np.arange(5.0)
result = (a + 1) * (a - 1) + 2 * a - 3
"#;
    assert_same(snippet);
}
/// Single-argument `np.arange` matches NumPy.
#[test]
fn arange_single_arg() {
    let snippet = r#"
import numpy as np
result = np.arange(7).astype("float64")
"#;
    assert_same(snippet);
}
/// `np.linspace` with start > stop counts downwards like NumPy.
#[test]
fn linspace_descending() {
    let snippet = r#"
import numpy as np
result = np.linspace(10.0, -10.0, 11)
"#;
    assert_same(snippet);
}
/// Summing a boolean array counts its `True` entries.
#[test]
fn bool_array_sum() {
    let src = r#"
import numpy as np
a = np.array([True, True, False, True])
result = np.array([float(a.sum())])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [3.0]);
}
/// A (0, 5) array keeps both dimensions and has size 0.
#[test]
fn zero_first_dim() {
    let src = r#"
import numpy as np
a = np.zeros((0, 5))
result = np.array([a.shape[0], a.shape[1], a.size]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 5.0, 0.0]);
}
/// A rectangular `np.eye(3, 5)` matches NumPy.
#[test]
fn eye_rect() {
    let snippet = r#"
import numpy as np
result = np.eye(3, 5)
"#;
    assert_same(snippet);
}
/// Nested sin/cos/log/abs ufunc calls match NumPy within tolerance.
#[test]
fn nested_ufuncs() {
    let snippet = r#"
import numpy as np
a = np.linspace(0.0, 2*np.pi, 8)
result = np.sin(np.cos(a)) + np.log(np.abs(a) + 1)
"#;
    assert_same(snippet);
}
/// Summing a 3-D array over its last axis matches NumPy.
#[test]
fn sum_3d_axis() {
    let snippet = r#"
import numpy as np
a = np.arange(24.0).reshape(2, 3, 4)
result = a.sum(axis=2)
"#;
    assert_same(snippet);
}
/// `np.any` along axis 0 and axis 1 matches NumPy.
#[test]
fn any_along_each_axis() {
    let snippet = r#"
import numpy as np
a = np.array([[1, 0, 0], [0, 0, 0], [0, 1, 0]])
result = np.array([np.any(a, axis=0).astype(int), np.any(a, axis=1).astype(int)])
"#;
    assert_same(snippet);
}
/// Row-wise `np.argsort` matches NumPy.
#[test]
fn argsort_axis_1() {
    let snippet = r#"
import numpy as np
a = np.array([[3.0, 1.0, 4.0], [9.0, 2.0, 6.0]])
result = np.argsort(a, axis=1).astype("float64")
"#;
    assert_same(snippet);
}
/// A zero-length middle axis gives `size == 0`.
#[test]
fn size_for_empty() {
    let src = r#"
import numpy as np
a = np.zeros((3, 0, 5))
result = np.array([a.size, a.shape[1]]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 0.0]);
}
/// `einsum('ii->')` computes the trace, matching NumPy.
#[test]
fn einsum_trace_3x3() {
    let snippet = r#"
import numpy as np
a = np.arange(9.0).reshape(3, 3)
result = np.array([float(np.einsum('ii->', a))])
"#;
    assert_same(snippet);
}
/// `einsum('iijj->')` on a 4-D array matches NumPy.
#[test]
fn einsum_double_trace() {
    let snippet = r#"
import numpy as np
a = np.arange(16.0).reshape(2, 2, 2, 2)
result = np.array([float(np.einsum('iijj->', a))])
"#;
    assert_same(snippet);
}
/// `einsum('ii->i')` extracts the diagonal, matching NumPy.
#[test]
fn einsum_diagonal() {
    let snippet = r#"
import numpy as np
a = np.arange(9.0).reshape(3, 3)
result = np.einsum('ii->i', a)
"#;
    assert_same(snippet);
}
/// `np.vstack` of three arrays with differing row counts matches NumPy.
#[test]
fn vstack_three_arrays() {
    let snippet = r#"
import numpy as np
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0]])
c = np.array([[7.0, 8.0], [9.0, 10.0]])
result = np.vstack([a, b, c])
"#;
    assert_same(snippet);
}
/// `np.hstack` of 2-D arrays with differing column counts matches NumPy.
#[test]
fn hstack_2d() {
    let snippet = r#"
import numpy as np
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0], [6.0]])
result = np.hstack([a, b])
"#;
    assert_same(snippet);
}
/// `np.nonzero(...)[0]` lists the indices of the non-zero entries.
#[test]
fn nonzero_returns_indices() {
    let src = r#"
import numpy as np
a = np.array([0.0, 1.0, 0.0, 2.0, 0.0, 3.0])
result = np.nonzero(a)[0].astype("float64")
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0, 3.0, 5.0]);
}
/// `tolist()` followed by `np.array()` round-trips the values.
#[test]
fn tolist_then_array_round_trip() {
    let snippet = r#"
import numpy as np
a = np.array([1.0, 2.0, 3.0])
lst = a.tolist()
result = np.array(lst)
"#;
    assert_same(snippet);
}
/// `np.repeat` without an axis flattens a 2-D input first, like NumPy.
#[test]
fn repeat_2d_array() {
    let snippet = r#"
import numpy as np
a = np.array([[1.0, 2.0], [3.0, 4.0]])
result = np.repeat(a, 2)
"#;
    assert_same(snippet);
}
/// `np.arange` with integer arguments yields an integer ('i' or 'u') dtype.
#[test]
fn arange_int_args_int_dtype() {
    let src = r#"
import numpy as np
a = np.arange(0, 5, 1)
result = np.array([a.dtype.kind == 'i' or a.dtype.kind == 'u']).astype(int).astype("float64")
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0]);
}
/// `np.conj` negates every imaginary part.
#[test]
fn conj_complex_array() {
    let src = r#"
import numpy as np
a = np.array([1+2j, 3-4j, 0+5j])
b = np.conj(a)
# Compare imaginary parts (negation)
result = np.array([b[0].imag, b[1].imag, b[2].imag])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [-2.0, 4.0, -5.0]);
}
/// For a real float array, `np.real` reproduces the input and `np.imag` is all
/// zeros; stacking both gives a (2, 3) result.
#[test]
fn real_imag_pure_real() {
    let r = rumpy_run(
        r#"
import numpy as np
a = np.array([1.0, 2.0, 3.0])
result = np.array([np.real(a), np.imag(a)])
"#,
    );
    assert_eq!(r.shape, vec![2, 3]);
    // Check both rows: the real row must reproduce the input, not only the
    // imaginary row being zero.
    assert_eq!(r.data, vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
}
/// A (1, 3) row plus a (3, 1) column broadcasts to (3, 3) like NumPy.
#[test]
fn broadcast_row_col_to_matrix() {
    let snippet = r#"
import numpy as np
a = np.array([[1.0, 2.0, 3.0]])
b = np.array([[10.0], [20.0], [30.0]])
result = a + b
"#;
    assert_same(snippet);
}
/// Combining comparison masks with `&` matches NumPy.
#[test]
fn comparison_chain_with_bitwise() {
    let snippet = r#"
import numpy as np
a = np.arange(10.0)
mask = (a > 2) & (a < 7)
result = mask.astype(int)
"#;
    assert_same(snippet);
}
/// Assigning a scalar through an integer index list matches NumPy.
#[test]
fn fancy_setitem_scalar() {
    let snippet = r#"
import numpy as np
a = np.arange(10.0)
a[[1, 3, 5]] = 99.0
result = a
"#;
    assert_same(snippet);
}
/// Assigning a scalar through a boolean mask matches NumPy.
#[test]
fn mask_setitem_scalar() {
    let snippet = r#"
import numpy as np
a = np.arange(10.0)
a[a % 2 == 0] = -1.0
result = a
"#;
    assert_same(snippet);
}
/// In-place `&=` on int32 arrays matches NumPy.
#[test]
fn iand_int_array() {
    let snippet = r#"
import numpy as np
a = np.array([0xFF, 0xFF, 0xFF, 0xFF], dtype="int32")
a &= np.array([0x0F, 0xF0, 0xAA, 0x55], dtype="int32")
result = a.astype("float64")
"#;
    assert_same(snippet);
}
/// Summing a transposed 3-D array over axis 0 matches NumPy.
#[test]
fn transpose_then_sum_3d() {
    let snippet = r#"
import numpy as np
a = np.arange(24.0).reshape(2, 3, 4)
result = a.T.sum(axis=0)
"#;
    assert_same(snippet);
}
/// Element-wise product with a reshaped-then-transposed view matches NumPy.
#[test]
fn array_times_transpose() {
    let snippet = r#"
import numpy as np
a = np.arange(1.0, 7.0).reshape(2, 3)
result = a * a.reshape(3, 2).T # same shape (2, 3) elementwise
"#;
    assert_same(snippet);
}
/// The default `np.linalg.norm` of [3, 4] is the Euclidean length 5.
#[test]
fn norm_1d_default() {
    let src = r#"
import numpy as np
a = np.array([3.0, 4.0])
result = np.array([float(np.linalg.norm(a))])
"#;
    let out = rumpy_run(src);
    let norm = out.data[0];
    assert!((norm - 5.0).abs() < 1e-9);
}
/// Element-wise `==` between 2-D integer arrays matches NumPy.
#[test]
fn equality_2d() {
    let snippet = r#"
import numpy as np
a = np.arange(6).reshape(2, 3)
b = np.array([[0, 1, 99], [3, 99, 5]])
result = (a == b).astype(int)
"#;
    assert_same(snippet);
}
/// Reshape followed by a row-wise sum matches NumPy.
#[test]
fn chain_reshape_sum() {
    let snippet = r#"
import numpy as np
a = np.arange(12.0)
result = a.reshape(3, 4).sum(axis=1)
"#;
    assert_same(snippet);
}
/// `np.iinfo("int8")` spans 255 between its min and max.
#[test]
fn iinfo_max_type() {
    let src = r#"
import numpy as np
info = np.iinfo("int8")
result = np.array([info.max - info.min]).astype("float64")
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [255.0]);
}
/// `np.finfo("float32").eps` must equal `f32::EPSILON` (2**-23), and adding it
/// to 1.0 must give a value greater than 1.0.
///
/// The check `1.0 + eps > 1.0` alone would pass for any positive `eps`, so
/// `eps` itself is also returned and compared against `f32::EPSILON`. `far`
/// is intentionally not asserted. It is computed in Python (f64) arithmetic,
/// where `1.0 + 0.4 * 2**-23` is still strictly greater than 1.0.
#[test]
fn finfo_eps_relation() {
    let r = rumpy_run(
        r#"
import numpy as np
eps = float(np.finfo("float32").eps)
near = 1.0 + eps
far = 1.0 + eps * 0.4
result = np.array([1.0 if near > 1.0 else 0.0, 1.0 if far > 1.0 else 0.0, eps])
"#,
    );
    assert_eq!(r.data[0], 1.0);
    let eps = r.data[2];
    assert!(
        (eps - f64::from(f32::EPSILON)).abs() < 1e-12,
        "finfo(float32).eps = {eps}, expected {}",
        f32::EPSILON
    );
}
/// (1, 4) + (3, 1) broadcasts to (3, 4) like NumPy.
#[test]
fn broadcast_1_n_with_m_1() {
    let snippet = r#"
import numpy as np
a = np.arange(4.0).reshape(1, 4)
b = np.arange(3.0).reshape(3, 1)
result = a + b
"#;
    assert_same(snippet);
}
/// Comparison against a 1-D array broadcasts over rows like NumPy.
#[test]
fn compare_broadcasts() {
    let snippet = r#"
import numpy as np
a = np.arange(6.0).reshape(2, 3)
result = (a > np.array([1.0, 2.0, 3.0])).astype(int)
"#;
    assert_same(snippet);
}
/// `np.cumsum` of an empty array is an empty 1-D array.
#[test]
fn cumsum_empty() {
    let src = r#"
import numpy as np
a = np.array([], dtype="float64")
result = np.cumsum(a)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.shape, [0]);
}
/// int32 * float32 multiplication matches NumPy after casting to float64.
#[test]
fn mul_mixed_int_float() {
    let snippet = r#"
import numpy as np
a = np.array([1, 2, 3], dtype="int32")
b = np.array([1.5, 2.5, 3.5], dtype="float32")
result = (a * b).astype("float64")
"#;
    assert_same(snippet);
}
/// After sorting, argmin is the first index and argmax the last.
#[test]
fn sort_then_argmin_is_zero() {
    let src = r#"
import numpy as np
a = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
b = np.sort(a)
result = np.array([np.argmin(b), np.argmax(b)]).astype(float)
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [0.0, 7.0]);
}
/// `len()` of a 2-D array is the length of its first axis.
#[test]
fn len_returns_first_axis() {
    let src = r#"
import numpy as np
a = np.arange(24.0).reshape(4, 6)
result = np.array([float(len(a))])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [4.0]);
}
/// Builtin `abs()` and `np.abs` agree (their difference matches NumPy's).
#[test]
fn abs_builtin_vs_np() {
    let snippet = r#"
import numpy as np
a = np.array([-1.0, -2.5, 3.0])
result = abs(a) - np.abs(a)
"#;
    assert_same(snippet);
}
/// `ndarray.ptp()` returns max - min as a scalar.
#[test]
fn ptp_method_returns_scalar() {
    let src = r#"
import numpy as np
a = np.array([1.0, 5.0, 3.0, 8.0])
result = np.array([float(a.ptp())])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [7.0]);
}
/// `np.count_nonzero` counts across every axis of a 2-D array.
#[test]
fn count_nonzero_2d() {
    let src = r#"
import numpy as np
a = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 3.0]])
result = np.array([float(np.count_nonzero(a))])
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [3.0]);
}
/// Broadcasting (1, 1) with (1,) matches NumPy.
#[test]
fn broadcast_all_size_1() {
    let snippet = r#"
import numpy as np
a = np.array([[5.0]])
b = np.array([2.0])
result = a * b
"#;
    assert_same(snippet);
}
/// Masking with `~np.isnan` keeps only the finite entries.
#[test]
fn mask_using_isnan() {
    let src = r#"
import numpy as np
a = np.array([1.0, float('nan'), 3.0, float('nan'), 5.0])
mask = ~np.isnan(a)
result = a[mask]
"#;
    let out = rumpy_run(src);
    assert_eq!(out.data, [1.0, 3.0, 5.0]);
}
/// Boolean filtering followed by `** 2` matches NumPy.
#[test]
fn composite_filter_then_square() {
    let snippet = r#"
import numpy as np
a = np.arange(10.0)
mask = a > 5
result = a[mask] ** 2
"#;
    assert_same(snippet);
}
/// sin/cos applied to broadcast expressions match NumPy within tolerance.
#[test]
fn chain_ufuncs_with_broadcasting() {
    let snippet = r#"
import numpy as np
a = np.arange(6.0).reshape(2, 3)
b = np.array([1.0, 2.0, 3.0])
result = np.sin(a * b) + np.cos(a + b)
"#;
    assert_same(snippet);
}