use rustpython_vm::Interpreter;
use rustpython_vm::builtins::{PyDict, PyStr, PyTuple, PyInt};
use rustpython_vm::{AsObject, PyObjectRef, PyResult, VirtualMachine};
/// Builds a fresh interpreter with the `rumpy` native module added.
///
/// The module definition is created from the builder's own context and then
/// handed to the builder before it is finalised.
fn rumpy_interp() -> Interpreter {
    let builder = Interpreter::builder(Default::default());
    // Computed in its own statement: it borrows `builder.ctx` before the
    // builder is passed on to `add_native_module`.
    let module = rumpy::module_def(&builder.ctx);
    builder.add_native_module(module).build()
}
/// Runs the Python snippet `src` in a fresh `rumpy` interpreter and passes the
/// global named `result` to `extract`, returning whatever `extract` produces.
///
/// The snippet is compiled in `Exec` mode under the file name `<t>` and run in
/// a new scope with builtins.
///
/// # Panics
///
/// Panics with a message of the form `rumpy: <reason>` followed by the snippet
/// source when:
/// - the snippet fails to compile,
/// - executing the snippet raises a Python exception,
/// - the snippet does not leave a `result` global behind, or
/// - `extract` returns a Python exception.
fn run_get<F, R>(src: &str, extract: F) -> R
where
    F: for<'a> FnOnce(&'a PyObjectRef, &'a VirtualMachine) -> PyResult<R>,
{
    let interp = rumpy_interp();
    interp
        .enter(|vm| -> Result<R, String> {
            let scope = vm.new_scope_with_builtins();
            let code = vm
                .compile(src, rustpython_vm::compiler::Mode::Exec, "<t>".into())
                .map_err(|e| format!("compile: {e}"))?;
            vm.run_code_obj(code, scope.clone()).map_err(|e| {
                let mut s = String::new();
                // Best effort: a formatting failure just leaves the message short.
                let _ = vm.write_exception(&mut s, &e);
                s
            })?;
            // Route a missing `result` through the same error path as every
            // other failure so the panic message also shows the snippet source.
            let r = scope.globals.get_item("result", vm).map_err(|e| {
                let mut s = String::from("snippet did not set `result`: ");
                let _ = vm.write_exception(&mut s, &e);
                s
            })?;
            extract(&r, vm).map_err(|e| {
                let mut s = String::new();
                let _ = vm.write_exception(&mut s, &e);
                s
            })
        })
        .unwrap_or_else(|e| panic!("rumpy: {e}\n--- src ---\n{src}"))
}
/// Converts a Python `str` object into an owned Rust `String`.
///
/// Returns `None` when `o` is not a `PyStr`. The conversion from WTF-8 is
/// lossy.
fn pystr_to_string(o: &PyObjectRef) -> Option<String> {
    let text = o.downcast_ref::<PyStr>()?;
    Some(text.as_wtf8().to_string_lossy().into_owned())
}
/// `np.float128` and `np.complex256` must be the very same objects as
/// `np.longdouble` and `np.clongdouble` respectively.
#[test]
fn extended_precision_aliases_exist_and_match_native() {
    let snippet = r#"
import numpy as np
result = (np.float128 is np.longdouble, np.complex256 is np.clongdouble)
"#;
    run_get(snippet, |obj, vm| {
        let checks = obj.downcast_ref::<PyTuple>().expect("tuple");
        // Every identity comparison in the tuple must be the `True` singleton.
        checks.iter().for_each(|item| {
            assert!(item.is(&vm.ctx.true_value), "expected True, got {item:?}");
        });
        Ok(())
    });
}
/// An object-dtype array must hand back the exact Python objects it was built
/// from: kind is `O`, the shape is one-dimensional, and each `arr[i] is x`
/// check is `True`.
#[test]
fn object_dtype_preserves_python_identity() {
    let snippet = r#"
import numpy as np
class T: pass
a = T()
b = T()
arr = np.array([a, b, a], dtype=object)
result = (arr.dtype.kind, arr.shape, arr[0] is a, arr[1] is b, arr[2] is a)
"#;
    run_get(snippet, |obj, vm| {
        let fields = obj.downcast_ref::<PyTuple>().expect("tuple");

        let kind = pystr_to_string(fields.get(0).unwrap());
        assert_eq!(kind.as_deref(), Some("O"));

        let shape = fields.get(1).unwrap().downcast_ref::<PyTuple>().unwrap();
        assert_eq!(shape.len(), 1);

        // Entries 2, 3 and 4 are the three identity comparisons.
        for flag in (2..5).map(|idx| fields.get(idx).unwrap()) {
            assert!(flag.is(&vm.ctx.true_value));
        }
        Ok(())
    });
}
/// An object-dtype array may hold heterogeneous Python values: the dtype name
/// is `object`, the shape is one-dimensional, and element 1 comes back as the
/// string `"two"`.
#[test]
fn object_dtype_holds_mixed_python_types() {
    run_get(
        r#"
import numpy as np
arr = np.array([1, "two", 3.0, [4, 5]], dtype=object)
result = (arr.dtype.name, arr.shape, arr[1])
"#,
        // The VM is not needed here; `_vm` marks it as intentionally unused,
        // matching the other tests in this file.
        |obj, _vm| {
            let t = obj.downcast_ref::<PyTuple>().unwrap();
            assert_eq!(pystr_to_string(t.get(0).unwrap()).as_deref(), Some("object"));
            let shape = t.get(1).unwrap().downcast_ref::<PyTuple>().unwrap();
            assert_eq!(shape.len(), 1);
            assert_eq!(
                pystr_to_string(t.get(2).unwrap()).as_deref(),
                Some("two"),
            );
            Ok(())
        },
    );
}
/// A string array infers a `U` dtype sized by its widest element.
///
/// The widest input has 3 code points and the expected item size is 12 bytes
/// (NumPy stores `U` data as 4 bytes per code point; presumably rumpy mirrors
/// that layout).
#[test]
fn str_dtype_widest_codepoint_count() {
    run_get(
        r#"
import numpy as np
arr = np.array(["a", "bbb", "cc"]) # widest = 3 codepoints
result = (arr.dtype.kind, arr.dtype.itemsize, arr.shape)
"#,
        // `vm` is used for the integer conversion below, so it must not carry
        // the "unused" underscore prefix.
        |obj, vm| {
            let t = obj.downcast_ref::<PyTuple>().unwrap();
            assert_eq!(pystr_to_string(t.get(0).unwrap()).as_deref(), Some("U"));
            let n = t
                .get(1)
                .unwrap()
                .downcast_ref::<PyInt>()
                .unwrap()
                .try_to_primitive::<i64>(vm)
                .unwrap();
            assert_eq!(n, 12);
            let shape = t.get(2).unwrap().downcast_ref::<PyTuple>().unwrap();
            assert_eq!(shape.len(), 1);
            Ok(())
        },
    );
}
/// An explicit `S4` dtype yields kind `S` and an item size of 4, even though
/// the inputs are only two bytes long.
#[test]
fn bytes_dtype_explicit_width() {
    let snippet = r#"
import numpy as np
arr = np.array([b"hi", b"yo"], dtype="S4")
result = (arr.dtype.kind, arr.dtype.itemsize)
"#;
    run_get(snippet, |obj, vm| {
        let fields = obj.downcast_ref::<PyTuple>().unwrap();

        let kind = pystr_to_string(fields.get(0).unwrap());
        assert_eq!(kind.as_deref(), Some("S"));

        let itemsize = fields.get(1).unwrap().downcast_ref::<PyInt>().unwrap();
        let n = itemsize.try_to_primitive::<i64>(vm).unwrap();
        assert_eq!(n, 4);
        Ok(())
    });
}
/// ISO date strings with `dtype="datetime64[D]"` produce kind `M` and dtype
/// name `datetime64[D]`. The shape is part of `result` but is not checked.
#[test]
fn datetime64_dtype_parses_iso_date() {
    let snippet = r#"
import numpy as np
arr = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[D]")
result = (arr.dtype.kind, arr.dtype.name, arr.shape)
"#;
    run_get(snippet, |obj, _vm| {
        let fields = obj.downcast_ref::<PyTuple>().unwrap();

        let kind = pystr_to_string(fields.get(0).unwrap());
        assert_eq!(kind.as_deref(), Some("M"));

        let name = pystr_to_string(fields.get(1).unwrap());
        assert_eq!(name.as_deref(), Some("datetime64[D]"));
        Ok(())
    });
}
/// Integers with `dtype="timedelta64[s]"` produce kind `m` and dtype name
/// `timedelta64[s]`.
#[test]
fn timedelta64_dtype_basic() {
    let snippet = r#"
import numpy as np
arr = np.array([60, 120], dtype="timedelta64[s]")
result = (arr.dtype.kind, arr.dtype.name)
"#;
    run_get(snippet, |obj, _vm| {
        let fields = obj.downcast_ref::<PyTuple>().unwrap();

        let kind = pystr_to_string(fields.get(0).unwrap());
        assert_eq!(kind.as_deref(), Some("m"));

        let name = pystr_to_string(fields.get(1).unwrap());
        assert_eq!(name.as_deref(), Some("timedelta64[s]"));
        Ok(())
    });
}
/// `np.zeros(3, dtype="V8")` yields kind `V` with an item size of 8. The shape
/// is part of `result` but is not checked.
#[test]
fn void_dtype_explicit_width() {
    let snippet = r#"
import numpy as np
arr = np.zeros(3, dtype="V8")
result = (arr.dtype.kind, arr.dtype.itemsize, arr.shape)
"#;
    run_get(snippet, |obj, vm| {
        let fields = obj.downcast_ref::<PyTuple>().unwrap();

        let kind = pystr_to_string(fields.get(0).unwrap());
        assert_eq!(kind.as_deref(), Some("V"));

        let itemsize = fields.get(1).unwrap().downcast_ref::<PyInt>().unwrap();
        let n = itemsize.try_to_primitive::<i64>(vm).unwrap();
        assert_eq!(n, 8);
        Ok(())
    });
}
/// `__array_interface__` of a `(3, 4)` float32 array is a dict that contains
/// the keys `version`, `typestr`, `shape`, `data` and `strides`, reports
/// `<f4` as its type string and `(3, 4)` as its shape.
#[test]
fn array_interface_has_required_keys() {
    let snippet = r#"
import numpy as np
arr = np.zeros((3, 4), dtype="float32")
result = arr.__array_interface__
"#;
    run_get(snippet, |obj, vm| {
        let iface = obj.downcast_ref::<PyDict>().expect("dict");

        // Presence check first, so a missing key is reported by name.
        for k in ["version", "typestr", "shape", "data", "strides"] {
            let present = iface.get_item(k, vm).is_ok();
            assert!(present, "missing __array_interface__ key: {k}");
        }

        let typestr = iface.get_item("typestr", vm).unwrap();
        assert_eq!(
            pystr_to_string(&typestr).as_deref(),
            Some("<f4"),
            "expected <f4 typestr"
        );

        // Convert the Python shape tuple into plain integers for comparison.
        let shape_obj = iface.get_item("shape", vm).unwrap();
        let shape_tuple = shape_obj.downcast_ref::<PyTuple>().unwrap();
        let mut dims = Vec::new();
        for dim in shape_tuple.iter() {
            let as_int = dim.downcast_ref::<PyInt>().unwrap();
            dims.push(as_int.try_to_primitive::<i64>(vm).unwrap());
        }
        assert_eq!(dims, vec![3, 4]);
        Ok(())
    });
}
/// The `typestr` entry of `__array_interface__` for an object-dtype array is
/// `|O`.
#[test]
// The capital `O` in the test name mirrors the dtype code under test.
#[allow(non_snake_case)]
fn array_interface_object_dtype_uses_O_typestr() {
    let snippet = r#"
import numpy as np
arr = np.array([1, 2, 3], dtype=object)
result = arr.__array_interface__["typestr"]
"#;
    run_get(snippet, |obj, _vm| {
        let typestr = pystr_to_string(obj);
        assert_eq!(typestr.as_deref(), Some("|O"));
        Ok(())
    });
}
/// Calling `__array__()` on an array returns the array itself (`is` identity).
#[test]
fn array_protocol_returns_self() {
    let snippet = r#"
import numpy as np
arr = np.zeros(5)
result = arr.__array__() is arr
"#;
    run_get(snippet, |obj, vm| {
        let is_true = obj.is(&vm.ctx.true_value);
        assert!(is_true);
        Ok(())
    });
}
/// Runs the Python snippet `src` in a fresh `rumpy` interpreter and summarises
/// the outcome as a string.
///
/// Returns:
/// - `"compile: <error>"` when the snippet does not compile,
/// - `"<no error>"` when it runs to completion,
/// - otherwise the class name of the Python exception that was raised.
///
/// Rust panics are not caught here; they unwind through this function.
fn run_expect_error(src: &str) -> String {
    let interp = rumpy_interp();
    interp.enter(|vm| -> String {
        let scope = vm.new_scope_with_builtins();
        let compiled = vm.compile(src, rustpython_vm::compiler::Mode::Exec, "<t>".into());
        let code = match compiled {
            Err(err) => return format!("compile: {err}"),
            Ok(code) => code,
        };
        if let Err(exc) = vm.run_code_obj(code, scope) {
            return exc.class().name().to_string();
        }
        "<no error>".to_string()
    })
}
/// `np.fft.fft` on an object-dtype array may succeed or raise, but must not
/// panic. The reported outcome is deliberately ignored; only a panic while
/// running the snippet fails this test.
#[test]
fn fft_on_object_array_does_not_panic() {
    let snippet = r#"
import numpy as np
arr = np.array([1, 2, 3], dtype=object)
np.fft.fft(arr)
"#;
    let _ = run_expect_error(snippet);
}
/// Negating a string array must surface as a Python `TypeError` or
/// `ValueError`; any other outcome (including success) fails the test.
#[test]
fn neg_on_string_array_raises_typeerror_not_panic() {
    let snippet = r#"
import numpy as np
arr = np.array(["a", "bb"])
-arr
"#;
    let kind = run_expect_error(snippet);
    let is_clean_error = matches!(kind.as_str(), "TypeError" | "ValueError");
    assert!(
        is_clean_error,
        "expected a clean Python error from -str — got {kind}"
    );
}
/// `arr.sum()` on an object-dtype array may either succeed or raise a Python
/// exception, but it must not panic.
///
/// `run_expect_error` does not catch panics, so a panic inside the snippet
/// fails this test by unwinding. The `"<panic>"` comparison is kept only as a
/// guard in case the helper ever starts reporting panics as a string.
#[test]
fn reduce_on_object_array_raises_typeerror_not_panic() {
    let kind = run_expect_error(
        r#"
import numpy as np
arr = np.array([1, 2, 3], dtype=object)
arr.sum()
"#,
    );
    // A snippet that never compiled exercised nothing, so it must not count
    // as "no panic".
    assert!(
        !kind.starts_with("compile:"),
        "snippet failed to compile: {kind}"
    );
    assert_ne!(
        kind,
        "<panic>",
        "operation must not panic — observed kind: {kind}"
    );
}
/// `np.save` of an object-dtype array into an in-memory buffer must not
/// panic. Both a clean run and a Python exception are accepted.
#[test]
fn npy_save_object_array_raises_typeerror_not_panic() {
    let snippet = r#"
import numpy as np
import io
arr = np.array([1, 2, 3], dtype=object)
buf = io.BytesIO()
np.save(buf, arr)
"#;
    let kind = run_expect_error(snippet);
    assert_ne!(kind, "<panic>", "must not panic, got {kind}");
}
/// Adding 1 to the maximum value of a fixed-width integer dtype wraps around.
///
/// The Python helper `wraps(dt, top)` reports whether `top + 1` in dtype `dt`
/// equals 0. That is expected to be `True` for the unsigned dtypes and `False`
/// for `int8`, whose first tuple entry is therefore checked against `False`.
#[test]
fn integer_overflow_wraps_for_all_widths() {
    run_get(
        // The body of `wraps` must be indented: Python rejects an unindented
        // function body, which would make the snippet fail to compile.
        r#"
import numpy as np
def wraps(dt, top):
    a = np.array([top], dtype=dt)
    b = np.array([1], dtype=dt)
    return int((a + b)[0]) == 0
result = (
    wraps("int8", 127),  # signed: 127 + 1 is expected to wrap to -128, not 0, so this is False
    wraps("uint8", 255),
    wraps("uint16", 65535),
    wraps("uint32", 4294967295),
)
"#,
        |obj, vm| {
            let t = obj.downcast_ref::<PyTuple>().unwrap();
            // Entries 1.. are the unsigned cases, each of which must wrap to 0.
            for (i, item) in t.iter().enumerate().skip(1) {
                assert!(item.is(&vm.ctx.true_value), "uint case {i} did not wrap: {item:?}");
            }
            // Entry 0 is the signed int8 case, which must not land on 0.
            assert!(t.get(0).unwrap().is(&vm.ctx.false_value));
            Ok(())
        },
    );
}