use ariadnetor_tensor::{ComputeBackendTensorExt, Host};
use ariadnetor_tensor::{DenseTensor, MemoryOrder};
/// The logical 2x3 matrix `[[1, 2, 3], [4, 5, 6]]`, relaid out into row-major storage.
///
/// Built from the column-major fixture so both layouts share one source of truth.
fn logical_2x3_row_major() -> DenseTensor<f64> {
    let column_major = logical_2x3_column_major();
    column_major.reordered(MemoryOrder::RowMajor)
}
/// Builds the logical 2x3 matrix `[[1, 2, 3], [4, 5, 6]]` from a column-major buffer.
///
/// The flat buffer lists the matrix column by column: `[1, 4]`, `[2, 5]`, `[3, 6]`.
fn logical_2x3_column_major() -> DenseTensor<f64> {
    let columns: Vec<f64> = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
    Host::shared().dense(columns, vec![2, 3])
}
/// Splitting a row-major length-6 vector into `[2, 3]` keeps row-major order.
///
/// Element `(i, j)` comes from flat index `i * 3 + j`, so the storage buffer is unchanged.
#[test]
fn split_leg_row_major_logical_grouping() {
    let flat = Host::shared().dense(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
    let row_major = flat.reordered(MemoryOrder::RowMajor);
    let split = row_major.split_leg(0, &[2, 3]);

    assert_eq!(split.shape(), &[2, 3]);
    assert_eq!(split.order(), MemoryOrder::RowMajor);
    assert_eq!(split.data_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
/// Splitting a column-major length-6 vector into `[2, 3]` uses the same logical grouping
/// as the row-major case (element `(i, j)` comes from flat index `i * 3 + j`).
///
/// In column-major storage, that grouping permutes the buffer.
#[test]
fn split_leg_column_major_logical_grouping() {
    let source = Host::shared().dense(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
    let split = source.split_leg(0, &[2, 3]);

    // The columns of [[1, 2, 3], [4, 5, 6]], laid end to end.
    let expected_storage = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
    assert_eq!(split.shape(), &[2, 3]);
    assert_eq!(split.order(), MemoryOrder::ColumnMajor);
    assert_eq!(split.data_slice(), &expected_storage);
}
/// Fusing both legs of the row-major 2x3 fixture flattens it row by row and keeps
/// row-major order.
#[test]
fn fuse_legs_row_major_logical_flatten() {
    let matrix = logical_2x3_row_major();
    let fused = matrix.fuse_legs(0..2);

    assert_eq!(fused.shape(), &[6]);
    assert_eq!(fused.order(), MemoryOrder::RowMajor);
    assert_eq!(fused.data_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
/// Fusing both legs of the column-major 2x3 fixture yields the same logical row-by-row
/// flattening as the row-major case, while reporting column-major order.
#[test]
fn fuse_legs_column_major_logical_flatten() {
    let matrix = logical_2x3_column_major();
    let fused = matrix.fuse_legs(0..2);

    let expected_storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(fused.shape(), &[6]);
    assert_eq!(fused.order(), MemoryOrder::ColumnMajor);
    assert_eq!(fused.data_slice(), &expected_storage);
}
/// `reshape_logical` regroups `[[1, 2, 3], [4, 5, 6]]` into the 3x2 matrix
/// `[[1, 2], [3, 4], [5, 6]]` while staying column-major.
#[test]
fn reshape_logical_column_major_regroup() {
    let matrix = logical_2x3_column_major();
    let regrouped = matrix.reshape_logical(vec![3, 2]);

    // Column 0 is [1, 3, 5] and column 1 is [2, 4, 6].
    assert_eq!(regrouped.shape(), &[3, 2]);
    assert_eq!(regrouped.order(), MemoryOrder::ColumnMajor);
    assert_eq!(regrouped.data_slice(), &[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
}
/// Fusing legs 0 and 1 of a column-major `[2, 3, 4]` tensor and then splitting them back
/// restores the original shape, order, and storage buffer.
#[test]
fn fuse_then_split_round_trips_column_major() {
    let original: Vec<f64> = (0..24_i32).map(f64::from).collect();
    let tensor = Host::shared().dense(original.clone(), vec![2, 3, 4]);

    let fused = tensor.fuse_legs(0..2);
    assert_eq!(fused.shape(), &[6, 4]);

    let restored = fused.split_leg(0, &[2, 3]);
    assert_eq!(restored.shape(), &[2, 3, 4]);
    assert_eq!(restored.order(), MemoryOrder::ColumnMajor);
    assert_eq!(restored.data_slice(), &original[..]);
}
/// Checks that one `reshape_logical([6, 2, 20])` call matches a chain of two `fuse_legs` calls.
///
/// The input is a `[2, 3, 2, 4, 5]` tensor stored in `order`. The chain fuses legs `0..2`,
/// then legs `2..4` of the already-fused tensor.
///
/// Comparing raw storage buffers is only meaningful when both tensors have the same shape
/// and memory order, so both are asserted on the chained result as well. Otherwise a
/// chain with the wrong shape but an identical buffer would pass unnoticed.
fn check_reshape_logical_multi_group(order: MemoryOrder) {
    let total = 2 * 3 * 2 * 4 * 5;
    let data: Vec<f64> = (0..total).map(|x| x as f64).collect();
    let t = Host::shared()
        .dense(data, vec![2, 3, 2, 4, 5])
        .reordered(order);
    let via_reshape = t.reshape_logical(vec![6, 2, 20]);
    // [2, 3, 2, 4, 5] -> fuse 0..2 -> [6, 2, 4, 5] -> fuse 2..4 -> [6, 2, 20]
    let via_chain = t.fuse_legs(0..2).fuse_legs(2..4);
    assert_eq!(via_reshape.shape(), &[6, 2, 20]);
    assert_eq!(via_reshape.order(), order);
    // Pin the chained result's layout so the buffer comparison below is meaningful.
    assert_eq!(via_chain.shape(), via_reshape.shape());
    assert_eq!(via_chain.order(), order);
    assert_eq!(via_reshape.data_slice(), via_chain.data_slice());
}
/// Runs the multi-group `reshape_logical` vs. chained `fuse_legs` check in both memory orders.
#[test]
fn reshape_logical_matches_chained_fuse_multi_group() {
    for &order in &[MemoryOrder::ColumnMajor, MemoryOrder::RowMajor] {
        check_reshape_logical_multi_group(order);
    }
}
/// Fusing a single-axis range (`1..2`) is a no-op: the shape, order, and buffer are unchanged.
#[test]
fn fuse_single_axis_is_logical_identity() {
    let original = logical_2x3_column_major();
    let fused = original.fuse_legs(1..2);

    assert_eq!(fused.shape(), &[2, 3]);
    assert_eq!(fused.order(), MemoryOrder::ColumnMajor);
    assert_eq!(fused.data_slice(), original.data_slice());
}
/// `fuse_legs` panics when the range runs past the tensor's rank (`0..5` on a rank-2 tensor).
/// The panic message is expected to mention `fuse_legs`.
#[test]
#[should_panic(expected = "fuse_legs")]
fn fuse_legs_panics_on_out_of_range() {
    let matrix = logical_2x3_row_major();
    let _ = matrix.fuse_legs(0..5);
}
/// `fuse_legs` panics on an empty axis range (`1..1`).
/// The panic message is expected to mention `fuse_legs`.
#[test]
#[should_panic(expected = "fuse_legs")]
fn fuse_legs_panics_on_empty_range() {
    let matrix = logical_2x3_row_major();
    let _ = matrix.fuse_legs(1..1);
}
/// `split_leg` panics when given an empty list of target dimensions.
/// The panic message is expected to mention `split_leg`.
#[test]
#[should_panic(expected = "split_leg")]
fn split_leg_panics_on_empty_into() {
    let matrix = logical_2x3_row_major();
    let _ = matrix.split_leg(0, &[]);
}
/// `split_leg` panics when the target dimensions' product (2 * 2 = 4) differs from the
/// size of the split axis (axis 0 has size 2).
/// The panic message is expected to mention `split_leg`.
#[test]
#[should_panic(expected = "split_leg")]
fn split_leg_panics_on_product_mismatch() {
    let matrix = logical_2x3_row_major();
    let _ = matrix.split_leg(0, &[2, 2]);
}
/// `split_leg` panics when the axis index is out of range (axis 2 on a rank-2 tensor).
/// The panic message is expected to mention `split_leg`.
#[test]
#[should_panic(expected = "split_leg")]
fn split_leg_panics_on_axis_out_of_range() {
    let matrix = logical_2x3_row_major();
    let _ = matrix.split_leg(2, &[1]);
}