use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use cutile;
use cutile::api;
use cutile::tensor::{IntoPartition, Reshape, ToHostVec};
use cutile::tile_kernel::DeviceOp;
mod common;
// Kernel entry points used by the host-side tests in this file.
//
// Every entry follows the same two-step shape: obtain a tile from `input` via a
// `load_tile_like_*` helper (passing `output` as the second argument), then
// store that tile into `output`.
//
// NOTE(review): the `-1` extents on the `input` tensor types presumably mark
// dimensions that are not fixed at compile time — confirm against the cutile docs.
// NOTE(review): `load_tile_like_*` is opaque from here; presumably it uses
// `output` to decide which tile of `input` to read — confirm.
#[cutile::module]
mod tensor_views_module {
use cutile::core::*;
// Rank-1 pass-through: `output` is typed as `[4]`, the loaded tile is `[4]`.
#[cutile::entry()]
fn passthrough_1d(output: &mut Tensor<f32, { [4] }>, input: &Tensor<f32, { [-1] }>) {
let tile: Tile<f32, { [4] }> = load_tile_like_1d(input, output);
output.store(tile);
}
// Rank-2 pass-through: `output` is typed as `[2, 2]`, the loaded tile is `[2, 2]`.
#[cutile::entry()]
fn passthrough_2d(output: &mut Tensor<f32, { [2, 2] }>, input: &Tensor<f32, { [-1, -1] }>) {
let tile: Tile<f32, { [2, 2] }> = load_tile_like_2d(input, output);
output.store(tile);
}
// Rank-1 pass-through with a `[2]` output and tile; otherwise the same as `passthrough_1d`.
#[cutile::entry()]
fn passthrough_1d_2(output: &mut Tensor<f32, { [2] }>, input: &Tensor<f32, { [-1] }>) {
let tile: Tile<f32, { [2] }> = load_tile_like_1d(input, output);
output.store(tile);
}
// Same signature and body as `passthrough_1d`; the slice tests launch this one.
#[cutile::entry()]
fn copy_4(output: &mut Tensor<f32, { [4] }>, input: &Tensor<f32, { [-1] }>) {
let tile: Tile<f32, { [4] }> = load_tile_like_1d(input, output);
output.store(tile);
}
// Same signature and body as `passthrough_2d`; the slice tests launch this one.
#[cutile::entry()]
fn copy_2x2(output: &mut Tensor<f32, { [2, 2] }>, input: &Tensor<f32, { [-1, -1] }>) {
let tile: Tile<f32, { [2, 2] }> = load_tile_like_2d(input, output);
output.store(tile);
}
}
/// Reshaping an `Arc`-wrapped tensor must alias the source allocation.
///
/// Verifies that each reshaped view reports the same device pointer as the
/// source, that shape and strides match the requested layout, that the host
/// copy of the flat view equals the host copy of the source, and that `dup()`
/// yields a different device pointer.
#[test]
fn arc_views_are_zero_copy() {
    common::with_test_stack(|| {
        let source = Arc::new(api::arange::<f32>(8).sync().expect("Failed."));
        let two_by_four = source.reshape(&[2, 4]).unwrap();
        let four_by_two = source.reshape(&[4, 2]).unwrap();
        let one_dim = source.reshape(&[8]).unwrap();

        // Pointer identity: every view reports the source's device address.
        assert_eq!(
            source.device_pointer().cu_deviceptr(),
            two_by_four.device_pointer().cu_deviceptr()
        );
        assert_eq!(
            source.device_pointer().cu_deviceptr(),
            four_by_two.device_pointer().cu_deviceptr()
        );
        assert_eq!(
            source.device_pointer().cu_deviceptr(),
            one_dim.device_pointer().cu_deviceptr()
        );

        // Layout metadata reported by each view.
        assert_eq!(two_by_four.shape(), vec![2, 4]);
        assert_eq!(two_by_four.strides(), vec![4, 1]);
        assert_eq!(four_by_two.shape(), vec![4, 2]);
        assert_eq!(four_by_two.strides(), vec![2, 1]);
        assert_eq!(one_dim.shape(), vec![8]);
        assert_eq!(one_dim.strides(), vec![1]);

        // Reading back through the flat view gives the same data as the source.
        let source_on_host: Vec<f32> = (&source).to_host_vec().sync().expect("Failed.");
        let view_on_host: Vec<f32> = one_dim.to_host_vec().sync().expect("Failed.");
        assert_eq!(source_on_host, view_on_host);

        // In contrast to `reshape`, `dup()` must not report the source's pointer.
        let duplicate = source.dup().sync().expect("Failed.");
        assert_ne!(
            source.device_pointer().cu_deviceptr(),
            duplicate.device_pointer().cu_deviceptr()
        );
    });
}
/// `reshape` must return an error when the requested shape does not cover
/// exactly the eight elements of the source tensor.
#[test]
fn invalid_view_shape_is_rejected() {
    common::with_test_stack(|| {
        let source = Arc::new(api::arange::<f32>(8).sync().expect("Failed."));

        // 5 elements requested from an 8-element tensor.
        let too_few_1d = source.reshape(&[5]);
        assert!(too_few_1d.is_err());

        // 2 * 2 = 4 elements requested from an 8-element tensor.
        let too_few_2d = source.reshape(&[2, 2]);
        assert!(too_few_2d.is_err());
    });
}
/// While a reshaped view is still alive, partitioning the source tensor is
/// expected to panic.
///
/// The outer `Arc` is unwrapped first so the only remaining sharing is through
/// the view held in `_alias`.
#[test]
fn shared_storage_blocks_mutable_partition() {
    common::with_test_stack(|| {
        let shared = Arc::new(api::arange::<f32>(8).sync().expect("Failed."));
        // Bound to a name (not `_`) so the view stays alive until the end of the test.
        let _alias = shared.reshape(&[2, 4]).unwrap();
        let unique = Arc::try_unwrap(shared).expect("Expected unique outer Arc.");

        // The partition attempt runs under `catch_unwind` so its panic can be observed.
        let guarded = AssertUnwindSafe(|| {
            let _ = unique.partition([8]);
        });
        let outcome = panic::catch_unwind(guarded);

        assert!(outcome.is_err(), "Expected mutable partition to be rejected");
    });
}
/// Two views of one four-element tensor, one rank-1 and one rank-2, are fed to
/// kernels of matching rank; both must read back `0.0..=3.0` in order.
#[test]
fn arc_views_work_with_different_rank_kernels() {
    common::with_test_stack(|| {
        let source = Arc::new(api::arange::<f32>(4).sync().expect("Failed."));
        let flat_view = source.reshape(&[4]).unwrap();
        let square_view = source.reshape(&[2, 2]).unwrap();

        // Rank-1 launch.
        let flat_out = api::zeros::<f32>(&[4]).sync().expect("Failed.");
        let flat_launch = tensor_views_module::passthrough_1d(flat_out.partition([4]), flat_view);
        let (flat_result, _flat_view) = flat_launch.sync().expect("Failed.");
        let flat_host: Vec<f32> = flat_result
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");

        // Rank-2 launch.
        let square_out = api::zeros::<f32>(&[2, 2]).sync().expect("Failed.");
        let square_launch =
            tensor_views_module::passthrough_2d(square_out.partition([2, 2]), square_view);
        let (square_result, _square_view) = square_launch.sync().expect("Failed.");
        let square_host: Vec<f32> = square_result
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");

        assert_eq!(flat_host, vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(square_host, vec![0.0, 1.0, 2.0, 3.0]);
    });
}
/// Slicing `4..8` out of an eight-element tensor yields a `[4]` view with unit
/// stride whose contents, copied through `copy_4`, are `4.0..=7.0`.
#[test]
fn slice_1d_offset() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(8).sync().expect("Failed.");
        let tail = source.slice(&[4..8]).expect("Failed.");
        assert_eq!(tail.shape(), &[4]);
        assert_eq!(tail.strides(), &[1]);

        let dest = api::zeros::<f32>(&[4]).sync().expect("Failed.");
        let launch = tensor_views_module::copy_4(dest.partition([4]), &tail);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![4.0, 5.0, 6.0, 7.0]);
    });
}
/// Slicing `2..6` out of an eight-element tensor yields a `[4]` view whose
/// contents, copied through `copy_4`, are `2.0..=5.0`.
#[test]
fn slice_1d_middle() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(8).sync().expect("Failed.");
        let middle = source.slice(&[2..6]).expect("Failed.");
        assert_eq!(middle.shape(), &[4]);

        let dest = api::zeros::<f32>(&[4]).sync().expect("Failed.");
        let launch = tensor_views_module::copy_4(dest.partition([4]), &middle);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![2.0, 3.0, 4.0, 5.0]);
    });
}
/// Slices rows `1..3` and columns `0..2` out of a 4x4 view of `arange(16)`.
///
/// Despite the test name, the column range is restricted as well. The slice
/// must keep the parent's `[4, 1]` strides and copy out as `4, 5, 8, 9`.
#[test]
fn slice_2d_rows_only() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let grid = source.view(&[4, 4]).expect("Failed.");
        let window = grid.slice(&[1..3, 0..2]).expect("Failed.");
        assert_eq!(window.shape(), &[2, 2]);
        assert_eq!(window.strides(), &[4, 1]);

        let dest = api::zeros::<f32>(&[2, 2]).sync().expect("Failed.");
        let launch = tensor_views_module::copy_2x2(dest.partition([2, 2]), &window);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![4.0, 5.0, 8.0, 9.0]);
    });
}
/// Slices rows `0..2` and columns `2..4` out of a 4x4 view of `arange(16)`.
///
/// The slice must keep the parent's `[4, 1]` strides and copy out as
/// `2, 3, 6, 7`.
#[test]
fn slice_2d_columns_non_contiguous() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let grid = source.view(&[4, 4]).expect("Failed.");
        let window = grid.slice(&[0..2, 2..4]).expect("Failed.");
        assert_eq!(window.shape(), &[2, 2]);
        assert_eq!(window.strides(), &[4, 1]);

        let dest = api::zeros::<f32>(&[2, 2]).sync().expect("Failed.");
        let launch = tensor_views_module::copy_2x2(dest.partition([2, 2]), &window);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![2.0, 3.0, 6.0, 7.0]);
    });
}
/// Slicing a slice: rows `1..3` of a 4x4 view, then rows `0..2` / columns
/// `1..3` of that result.
///
/// The final view must be `[2, 2]` with strides `[4, 1]` and copy out as
/// `5, 6, 9, 10`.
#[test]
fn slice_chained() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let grid = source.view(&[4, 4]).expect("Failed.");

        // First cut: rows only, columns left untouched.
        let middle_rows = grid.slice(&[1..3]).expect("Failed.");
        // Second cut is expressed relative to the first.
        let inner = middle_rows.slice(&[0..2, 1..3]).expect("Failed.");
        assert_eq!(inner.shape(), &[2, 2]);
        assert_eq!(inner.strides(), &[4, 1]);

        let dest = api::zeros::<f32>(&[2, 2]).sync().expect("Failed.");
        let launch = tensor_views_module::copy_2x2(dest.partition([2, 2]), &inner);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![5.0, 6.0, 9.0, 10.0]);
    });
}
/// `slice` must return an error for ranges that run past an axis or whose
/// start exceeds their end, on both 1-D and 2-D tensors.
#[test]
fn slice_out_of_bounds_rejected() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(8).sync().expect("Failed.");

        // End past the eight available elements.
        assert!(source.slice(&[0..9]).is_err());
        // Start greater than end.
        assert!(source.slice(&[5..3]).is_err());

        let grid = source.view(&[2, 4]).expect("Failed.");
        // Three rows requested from a two-row view.
        assert!(grid.slice(&[0..3]).is_err());
        // Five columns requested from a four-column view.
        assert!(grid.slice(&[0..2, 0..5]).is_err());
    });
}
/// A column slice (`0..4`, `1..3`) of a 4x4 view must refuse to be re-viewed
/// as a flat `[8]` tensor.
#[test]
fn slice_preserves_view_reshape_rejection() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let grid = source.view(&[4, 4]).expect("Failed.");
        let inner_columns = grid.slice(&[0..4, 1..3]).expect("Failed.");

        let flattened = inner_columns.view(&[8]);
        assert!(
            flattened.is_err(),
            "reshape on non-contiguous view should fail"
        );
    });
}
/// Slicing the whole `0..16` range leaves shape `[16]` and stride `[1]`.
#[test]
fn slice_full_range_is_identity() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let everything = source.slice(&[0..16]).expect("Failed.");

        assert_eq!(everything.shape(), &[16]);
        assert_eq!(everything.strides(), &[1]);
    });
}
/// A one-wide range (`5..6`) produces a view of shape `[1]`.
#[test]
fn slice_single_element() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let single = source.slice(&[5..6]).expect("Failed.");

        assert_eq!(single.shape(), &[1]);
    });
}
/// Giving `slice` one range for a 2-D view restricts only the first axis:
/// rows `1..3` of a 4x4 view come back as `[2, 4]` with strides `[4, 1]`.
#[test]
fn slice_2d_partial_axes() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");
        let grid = source.view(&[4, 4]).expect("Failed.");
        let middle_rows = grid.slice(&[1..3]).expect("Failed.");

        assert_eq!(middle_rows.shape(), &[2, 4]);
        assert_eq!(middle_rows.strides(), &[4, 1]);
    });
}
/// Offsets of nested 1-D slices must add up: `2..10` followed by `3..5`
/// selects source elements 5 and 6.
#[test]
fn slice_chained_accumulates_offset() {
    common::with_test_stack(|| {
        let source = api::arange::<f32>(16).sync().expect("Failed.");

        let outer = source.slice(&[2..10]).expect("Failed.");
        assert_eq!(outer.shape(), &[8]);

        // Range is relative to `outer`, i.e. source indices 2 + 3 .. 2 + 5.
        let inner = outer.slice(&[3..5]).expect("Failed.");
        assert_eq!(inner.shape(), &[2]);

        let dest = api::zeros::<f32>(&[2]).sync().expect("Failed.");
        let launch = tensor_views_module::passthrough_1d_2(dest.partition([2]), &inner);
        let (copied, _) = launch.sync().expect("Failed.");

        let on_host: Vec<f32> = copied
            .unpartition()
            .to_host_vec()
            .sync()
            .expect("Failed.");
        assert_eq!(on_host, vec![5.0, 6.0]);
    });
}