// On wasm32 targets only: configure wasm-bindgen-test with the `run_in_browser` option
// for the tests in this file.
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
use runmat_builtins::Value;
use runmat_runtime as rt;
/// `squeeze` applied to a 1x2x2 tensor must produce a tensor of shape `[2, 2]`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squeeze_basic() {
    let input = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();
    let args = [Value::Tensor(input)];
    let result = rt::call_builtin("squeeze", &args).unwrap();
    if let Value::Tensor(squeezed) = result {
        assert_eq!(squeezed.shape, vec![2, 2]);
    } else {
        panic!("expected tensor");
    }
}
/// `permute` with order `[2, 1]` applied to a 2x3 tensor must produce a 3x2 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn permute_swap_dims() {
    let t = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let ord = runmat_builtins::Tensor::new(vec![2.0, 1.0], vec![1, 2]).unwrap();
    // `t` is not used after this call, so it is moved into the argument list
    // instead of being cloned.
    let v = rt::call_builtin("permute", &[Value::Tensor(t), Value::Tensor(ord)]).unwrap();
    match v {
        Value::Tensor(p) => assert_eq!(p.shape, vec![3, 2]),
        _ => panic!("expected tensor"),
    }
}
/// `cat` along dimension 1 of two 2x2 tensors must produce a 4x2 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cat_dim1() {
    let first = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let second = runmat_builtins::Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
    // The concatenation dimension is passed first, followed by the operands.
    let args = [Value::Num(1.0), Value::Tensor(first), Value::Tensor(second)];
    let result = rt::call_builtin("cat", &args).unwrap();
    match result {
        Value::Tensor(joined) => assert_eq!(joined.shape, vec![4, 2]),
        _ => panic!("expected tensor"),
    }
}
/// `cat` along dimension 2 of three 2x1 tensors must produce a 2x3 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cat_variadic_three_inputs() {
    let first = runmat_builtins::Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    let second = runmat_builtins::Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
    let third = runmat_builtins::Tensor::new(vec![5.0, 6.0], vec![2, 1]).unwrap();
    // The concatenation dimension is passed first, followed by all three operands.
    let args = [
        Value::Num(2.0),
        Value::Tensor(first),
        Value::Tensor(second),
        Value::Tensor(third),
    ];
    let result = rt::call_builtin("cat", &args).unwrap();
    match result {
        Value::Tensor(joined) => assert_eq!(joined.shape, vec![2, 3]),
        _ => panic!("expected tensor"),
    }
}
/// `repmat` of a 2x2 tensor with scalar factors 2 and 3 must produce a 4x6 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repmat_2d() {
    let tile = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let args = [Value::Tensor(tile), Value::Num(2.0), Value::Num(3.0)];
    let result = rt::call_builtin("repmat", &args).unwrap();
    match result {
        Value::Tensor(tiled) => assert_eq!(tiled.shape, vec![4, 6]),
        _ => panic!("expected tensor"),
    }
}
/// `repmat` of a 1x2 tensor with the repetition vector `[3, 4]` must produce a 3x8 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repmat_nd_vector_form() {
    let tile = runmat_builtins::Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
    let factors = runmat_builtins::Tensor::new(vec![3.0, 4.0], vec![1, 2]).unwrap();
    let args = [Value::Tensor(tile), Value::Tensor(factors)];
    let result = rt::call_builtin("repmat", &args).unwrap();
    match result {
        Value::Tensor(tiled) => assert_eq!(tiled.shape, vec![3, 8]),
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of the row vector `[1, 2, 3]` with scalar factor 2 must repeat each
/// element twice, giving a 1x6 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_row_vector_scalar_form() {
    let row = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let args = [Value::Tensor(row), Value::Num(2.0)];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Tensor(expanded) => {
            assert_eq!(expanded.shape, vec![1, 6]);
            assert_eq!(expanded.data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
        }
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of the row vector `[1, 2, 3]` with per-element counts `[1, 2, 3]` must
/// produce the 1x6 tensor `[1, 2, 2, 3, 3, 3]`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_row_vector_count_form() {
    let row = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let per_element = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let args = [Value::Tensor(row), Value::Tensor(per_element)];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Tensor(expanded) => {
            assert_eq!(expanded.shape, vec![1, 6]);
            assert_eq!(expanded.data, vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0]);
        }
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of a 3x3 tensor with factors 2 and 3 must produce a 6x9 tensor; a few
/// entries of the flat data buffer are spot-checked.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_matrix_block_form() {
    let source_data = vec![8.0, 3.0, 4.0, 1.0, 5.0, 9.0, 6.0, 7.0, 2.0];
    let magic = runmat_builtins::Tensor::new(source_data, vec![3, 3]).unwrap();
    let args = [Value::Tensor(magic), Value::Num(2.0), Value::Num(3.0)];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Tensor(expanded) => {
            assert_eq!(expanded.shape, vec![6, 9]);
            let data = &expanded.data;
            assert_eq!(data[0], 8.0);
            assert_eq!(data[1], 8.0);
            assert_eq!(data[6], 8.0);
            assert_eq!(data[7], 8.0);
            assert_eq!(data[8], 3.0);
        }
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of an empty 0x1 tensor with factor 3 must stay 0x1 and hold no data.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_empty_column_preserves_orientation() {
    let empty_column = runmat_builtins::Tensor::new(Vec::new(), vec![0, 1]).unwrap();
    let args = [Value::Tensor(empty_column), Value::Num(3.0)];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Tensor(expanded) => {
            assert_eq!(expanded.shape, vec![0, 1]);
            assert!(expanded.data.is_empty());
        }
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of a 1x1x3 tensor with a single factor of 2 must repeat along the third
/// axis, giving a 1x1x6 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_single_factor_uses_unique_nd_non_singleton_axis() {
    let pages = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 1, 3]).unwrap();
    let args = [Value::Tensor(pages), Value::Num(2.0)];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Tensor(expanded) => {
            assert_eq!(expanded.shape, vec![1, 1, 6]);
            assert_eq!(expanded.data, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
        }
        _ => panic!("expected tensor"),
    }
}
/// `repelem` of a 1x1x2 cell array with factors (1, 1, 2) must produce a 1x1x4 cell
/// array whose numeric elements are `[1, 1, 2, 2]`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn repelem_cell_array_nd() {
    let elements = vec![Value::Num(1.0), Value::Num(2.0)];
    let cell = runmat_builtins::CellArray::new_with_shape(elements, vec![1, 1, 2]).unwrap();
    let args = [
        Value::Cell(cell),
        Value::Num(1.0),
        Value::Num(1.0),
        Value::Num(2.0),
    ];
    let result = rt::call_builtin("repelem", &args).unwrap();
    match result {
        Value::Cell(out) => {
            assert_eq!(out.shape, vec![1, 1, 4]);
            let mut values: Vec<f64> = Vec::new();
            for ptr in out.data.iter() {
                // SAFETY (assumed, TODO confirm): the element handles stored in `out.data`
                // are expected to point at live values for as long as `out` is alive.
                let element = unsafe { &*ptr.as_raw() };
                match element {
                    Value::Num(n) => values.push(*n),
                    other => panic!("expected numeric cell element, got {other:?}"),
                }
            }
            assert_eq!(values, vec![1.0, 1.0, 2.0, 2.0]);
        }
        _ => panic!("expected cell"),
    }
}
/// `linspace(0, 1, 5)` must produce a 1x5 tensor whose last element is 1 (within 1e-9).
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linspace_basic() {
    // The point count is passed as an integer value rather than a double.
    let count = Value::Int(runmat_builtins::IntValue::I32(5));
    let args = [Value::Num(0.0), Value::Num(1.0), count];
    let result = rt::call_builtin("linspace", &args).unwrap();
    match result {
        Value::Tensor(points) => {
            assert_eq!(points.shape, vec![1, 5]);
            let last = points.data[4];
            assert!((last - 1.0).abs() < 1e-9);
        }
        _ => panic!("expected tensor"),
    }
}
/// `meshgrid` of a 1x3 `x` and a 2x1 `y` must return a tensor of shape 2x3 or 3x2.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn meshgrid_basic() {
    let xs = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let ys = runmat_builtins::Tensor::new(vec![10.0, 20.0], vec![2, 1]).unwrap();
    let args = [Value::Tensor(xs), Value::Tensor(ys)];
    let result = rt::call_builtin("meshgrid", &args).unwrap();
    match result {
        Value::Tensor(grid) => {
            // Either orientation is accepted by this test.
            let is_2x3 = grid.shape == vec![2, 3];
            let is_3x2 = grid.shape == vec![3, 2];
            assert!(is_2x3 || is_3x2);
        }
        _ => panic!("expected tensor"),
    }
}
/// `diag` of a 3x1 vector must produce a 3x3 tensor, and `diag` of that result must
/// give back a 3x1 tensor whose last element is 3 (within 1e-9).
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn diag_vector_to_matrix_and_back() {
    let v = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
    let m = rt::call_builtin("diag", &[Value::Tensor(v)]).unwrap();
    // Fail loudly if the intermediate result is not a tensor; previously the 3x3
    // shape check was silently skipped in that case.
    match &m {
        Value::Tensor(mt) => assert_eq!(mt.shape, vec![3, 3]),
        _ => panic!("expected tensor"),
    }
    let d = rt::call_builtin("diag", &[m]).unwrap();
    match d {
        Value::Tensor(dt) => {
            assert_eq!(dt.shape, vec![3, 1]);
            assert!((dt.data[2] - 3.0).abs() < 1e-9);
        }
        _ => panic!("expected tensor"),
    }
}
/// `triu` and `tril` of a 2x2 tensor must each return a 2x2 tensor.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn triu_tril_shapes() {
    let square = runmat_builtins::Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    // The input is needed by both calls, so the first one receives a copy.
    let upper = rt::call_builtin("triu", &[Value::Tensor(square.clone())]).unwrap();
    let lower = rt::call_builtin("tril", &[Value::Tensor(square)]).unwrap();
    match upper {
        Value::Tensor(ut) => assert_eq!(ut.shape, vec![2, 2]),
        _ => panic!("expected tensor"),
    }
    match lower {
        Value::Tensor(lt) => assert_eq!(lt.shape, vec![2, 2]),
        _ => panic!("expected tensor"),
    }
}