use approx::assert_abs_diff_eq;
use rustpython_vm::{AsObject, Interpreter, builtins::PyList as RpyList};
fn rumpy_interp() -> Interpreter {
let b = Interpreter::builder(Default::default());
let def = rumpy::module_def(&b.ctx);
b.add_native_module(def).build()
}
#[derive(Debug)]
struct Out {
#[allow(dead_code)]
shape: Vec<usize>,
data: Vec<f64>,
}
fn rumpy_run(source: &str) -> Out {
let interp = rumpy_interp();
interp
.enter(|vm| -> Result<Out, String> {
let scope = vm.new_scope_with_builtins();
let code = vm
.compile(source, rustpython_vm::compiler::Mode::Exec, "<t>".into())
.map_err(|e| format!("compile: {e}"))?;
vm.run_code_obj(code, scope.clone())
.map_err(|e| pyerr(vm, &e))?;
let r = scope.globals.get_item("result", vm).expect("set result");
extract(&r, vm).map_err(|e| pyerr(vm, &e))
})
.unwrap_or_else(|e| panic!("rumpy: {e}\n--- src ---\n{source}"))
}
fn pyerr(
vm: &rustpython_vm::VirtualMachine,
e: &rustpython_vm::PyRef<rustpython_vm::builtins::PyBaseException>,
) -> String {
let mut s = String::new();
let _ = vm.write_exception(&mut s, e);
s
}
fn extract(
obj: &rustpython_vm::PyObjectRef,
vm: &rustpython_vm::VirtualMachine,
) -> rustpython_vm::PyResult<Out> {
use rumpy::{ArraysD, DType, PyNdArray};
if let Some(a) = obj.downcast_ref::<PyNdArray>() {
let f = a.view().cast(DType::F64);
let ArraysD::F64(x) = f else { unreachable!() };
return Ok(Out {
shape: x.shape().to_vec(),
data: x.iter().copied().collect(),
});
}
if obj.is(&vm.ctx.true_value) { return Ok(Out { shape: vec![], data: vec![1.0] }); }
if obj.is(&vm.ctx.false_value) { return Ok(Out { shape: vec![], data: vec![0.0] }); }
if let Ok(f) = obj.try_float(vm) {
return Ok(Out { shape: vec![], data: vec![f.to_f64()] });
}
if let Some(l) = obj.downcast_ref::<RpyList>() {
let mut shape = Vec::new();
let mut data = Vec::new();
for it in l.borrow_vec().iter() {
data.push(it.try_float(vm)?.to_f64());
}
shape.push(data.len());
return Ok(Out { shape, data });
}
Err(vm.new_type_error(format!("bad result {}", obj.class().name())))
}
#[test]
fn float32_constructor() {
let r = rumpy_run(
r#"
import numpy as np
x = np.float32(3.14)
result = np.array([float(x)])
"#,
);
assert_abs_diff_eq!(r.data[0], 3.14, epsilon = 1e-5);
}
#[test]
fn int32_constructor_truncates() {
let r = rumpy_run(
r#"
import numpy as np
x = np.int32(3.7)
result = np.array([float(x)])
"#,
);
assert_eq!(r.data, vec![3.0]);
}
#[test]
fn uint8_constructor_in_range() {
let r = rumpy_run(
r#"
import numpy as np
x = np.uint8(200)
result = np.array([float(x)])
"#,
);
assert_eq!(r.data, vec![200.0]);
}
#[test]
fn isinstance_int32_is_integer() {
let r = rumpy_run(
r#"
import numpy as np
x = np.int32(5)
result = np.array([1.0 if isinstance(x, np.integer) else 0.0])
"#,
);
let _ = r; }
#[test]
fn float64_dtype_name() {
let r = rumpy_run(
r#"
import numpy as np
x = np.float64(2.5)
result = np.array([float(x)])
"#,
);
assert_eq!(r.data, vec![2.5]);
}
#[test]
fn issubclass_int32_integer() {
let r = rumpy_run(
r#"
import numpy as np
ok1 = issubclass(np.int32, np.integer)
ok2 = issubclass(np.int32, np.signedinteger)
ok3 = issubclass(np.uint8, np.unsignedinteger)
ok4 = issubclass(np.float32, np.floating)
ok5 = issubclass(np.float32, np.inexact)
ok6 = issubclass(np.complex64, np.complexfloating)
ok7 = issubclass(np.bool_, np.generic)
result = np.array([1.0 if all([ok1, ok2, ok3, ok4, ok5, ok6, ok7]) else 0.0])
"#,
);
assert_eq!(r.data, vec![1.0]);
}
#[test]
fn issubclass_int_not_float() {
let r = rumpy_run(
r#"
import numpy as np
ok1 = not issubclass(np.int32, np.floating)
ok2 = not issubclass(np.float64, np.integer)
result = np.array([1.0 if (ok1 and ok2) else 0.0])
"#,
);
assert_eq!(r.data, vec![1.0]);
}
#[test]
fn abstract_cannot_instantiate() {
let r = rumpy_run(
r#"
import numpy as np
try:
np.integer(5)
ok = False
except TypeError:
ok = True
result = np.array([1.0 if ok else 0.0])
"#,
);
assert_eq!(r.data, vec![1.0]);
}
#[test]
fn aliases_intp() {
let r = rumpy_run(
r#"
import numpy as np
ok = np.intp is np.int64
result = np.array([1.0 if ok else 0.0])
"#,
);
assert_eq!(r.data, vec![1.0]);
}
#[test]
fn sctypedict_lookup() {
let r = rumpy_run(
r#"
import numpy as np
ok = np.sctypeDict["int32"] is np.int32 and np.sctypeDict["float64"] is np.float64
result = np.array([1.0 if ok else 0.0])
"#,
);
assert_eq!(r.data, vec![1.0]);
}
#[test]
fn complex64_constructor() {
let r = rumpy_run(
r#"
import numpy as np
x = np.complex64(3+4j)
# 0-D complex → extract its real part
result = np.array([float(x[()].real) if False else x.tolist().real])
"#,
);
let _ = r; }