use serial_test::serial;
use tl_runtime::*;
/// Panics with "Tensor pointer is null" when `t` is a null pointer.
///
/// Only nullness is checked; the pointee is never dereferenced here.
fn assert_tensor_valid(t: *mut OpaqueTensor) {
    if t.is_null() {
        panic!("Tensor pointer is null");
    }
}
/// Hands `t` to `tl_tensor_free`, silently ignoring null pointers.
fn safe_free(t: *mut OpaqueTensor) {
    // Nothing to release for a null handle.
    if t.is_null() {
        return;
    }
    tl_tensor_free(t);
}
/// Creates a tensor from four host values with shape `[2, 2]`, checks the
/// length reported by the runtime, and frees the tensor.
#[test]
#[serial]
fn test_tensor_creation_and_free() {
    let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let dims: Vec<usize> = vec![2, 2];

    let tensor = tl_tensor_new(values.as_ptr(), 2, dims.as_ptr());
    assert_tensor_valid(tensor);

    // NOTE(review): the runtime reports 2 for this [2, 2] tensor — presumably
    // the size of the leading dimension rather than the element count; confirm
    // against tl_runtime.
    assert_eq!(tl_tensor_len(tensor), 2);

    safe_free(tensor);
}
/// Adds two rank-1 tensors and verifies every element of the result.
///
/// Both inputs are null-checked before use, and the sum is compared
/// element by element against the host-side sums of the input data, so a
/// non-null but numerically wrong result fails the test.
#[test]
#[serial]
fn test_tensor_arithmetic() {
    let data_a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let data_b: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
    let shape: Vec<usize> = vec![4];

    let t_a = tl_tensor_new(data_a.as_ptr(), 1, shape.as_ptr());
    let t_b = tl_tensor_new(data_b.as_ptr(), 1, shape.as_ptr());
    assert_tensor_valid(t_a);
    assert_tensor_valid(t_b);

    let t_add = tl_tensor_add(t_a, t_b);
    assert_tensor_valid(t_add);

    // Expected values are derived from the inputs so the table cannot drift
    // out of sync with the data above.
    for (idx, (a, b)) in data_a.iter().zip(&data_b).enumerate() {
        assert_approx_eq(get_item_f32(t_add, idx), a + b);
    }

    safe_free(t_a);
    safe_free(t_b);
    safe_free(t_add);
}
/// Requests a rank-2 tensor of shape `[2, 5]` from `tl_tensor_zeros` and
/// checks that a non-null handle comes back.
#[test]
#[serial]
fn test_tensor_zeros() {
    let dims: Vec<usize> = vec![2, 5];

    // NOTE(review): the trailing `false` is presumably a gradient-tracking
    // flag — confirm against the tl_runtime signature.
    let zeros = tl_tensor_zeros(2, dims.as_ptr(), false);

    assert_tensor_valid(zeros);
    safe_free(zeros);
}
/// Reads one `f32` from `t` using a single index, via `tl_tensor_get_f32_md`
/// with an index count of 1.
fn get_item_f32(t: *mut OpaqueTensor, idx: usize) -> f32 {
    // The runtime takes indices as i64; test indices are tiny, so the cast is lossless.
    let position: [i64; 1] = [idx as i64];
    let value = tl_tensor_get_f32_md(t, position.as_ptr(), 1);
    value
}
/// Reads the value of `t` through `tl_tensor_get_f32_md` with an empty index
/// list (index count 0).
fn get_scalar_f32(t: *mut OpaqueTensor) -> f32 {
    // Zero-length index array: the pointer is passed together with a count of 0.
    let no_indices: [i64; 0] = [];
    tl_tensor_get_f32_md(t, no_indices.as_ptr(), 0)
}
/// Panics unless `a` is within an absolute tolerance of `1e-4` of `b`.
///
/// `a` is reported as the value that was "got" and `b` as the "expected"
/// one. A NaN difference never compares as less than the tolerance, so NaN
/// inputs always fail.
fn assert_approx_eq(a: f32, b: f32) {
    const TOLERANCE: f32 = 1e-4;

    let diff = f32::abs(a - b);
    let within_tolerance = diff < TOLERANCE;
    assert!(within_tolerance, "Expected {}, got {} (diff {})", b, a, diff);
}
/// Multiplies a `[2, 3]` tensor by a `[3, 2]` tensor and verifies the shape
/// and every element of the `[2, 2]` product.
///
/// Checking only element `[0, 0]` is not enough: it equals 22 whether the
/// input buffers are read row-major or column-major, so the remaining three
/// elements are what pin down the layout and the off-diagonal results.
#[test]
#[serial]
fn test_matmul() {
    let data_a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let data_b: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape_a: Vec<usize> = vec![2, 3];
    let shape_b: Vec<usize> = vec![3, 2];

    let t_a = tl_tensor_new(data_a.as_ptr(), 2, shape_a.as_ptr());
    let t_b = tl_tensor_new(data_b.as_ptr(), 2, shape_b.as_ptr());
    assert_tensor_valid(t_a);
    assert_tensor_valid(t_b);

    let t_c = tl_tensor_matmul(t_a, t_b);
    assert_tensor_valid(t_c);

    // SAFETY: `t_c` was just checked to be non-null and is assumed to point at
    // a live `OpaqueTensor` returned by the runtime; it is not freed until the
    // end of this test.
    let shape_c = unsafe { (*t_c).0.dims().to_vec() };
    assert_eq!(shape_c, vec![2, 2]);

    // [[1, 2, 3], [4, 5, 6]] x [[1, 2], [3, 4], [5, 6]] = [[22, 28], [49, 64]]
    let expected: [([i64; 2], f32); 4] = [
        ([0, 0], 22.0),
        ([0, 1], 28.0),
        ([1, 0], 49.0),
        ([1, 1], 64.0),
    ];
    for (indices, want) in expected.iter() {
        let got = tl_tensor_get_f32_md(t_c, indices.as_ptr(), 2);
        assert_approx_eq(got, *want);
    }

    safe_free(t_a);
    safe_free(t_b);
    safe_free(t_c);
}
/// Sums a four-element tensor with `tl_tensor_sum` and checks that the
/// result reads back as 10.
#[test]
#[serial]
fn test_sum() {
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let dims = vec![4];
    let input = tl_tensor_new(values.as_ptr(), 1, dims.as_ptr());

    let total = tl_tensor_sum(input);
    assert_tensor_valid(total);

    // 1 + 2 + 3 + 4
    assert_approx_eq(get_scalar_f32(total), 10.0);

    safe_free(input);
    safe_free(total);
}
/// Checks `tl_tensor_sqrt` on perfect squares and `tl_tensor_exp` on zero.
#[test]
#[serial]
fn test_math_ops() {
    // sqrt([1, 4, 9]) should read back as [1, 2, 3].
    let squares = vec![1.0, 4.0, 9.0];
    let squares_shape = vec![3];
    let t = tl_tensor_new(squares.as_ptr(), 1, squares_shape.as_ptr());
    let t_sqrt = tl_tensor_sqrt(t);
    let expected_roots: [f32; 3] = [1.0, 2.0, 3.0];
    for (idx, root) in expected_roots.iter().enumerate() {
        assert_approx_eq(get_item_f32(t_sqrt, idx), *root);
    }

    // exp(0) should read back as 1.
    let zero = vec![0.0];
    let zero_shape = vec![1];
    let t_zero = tl_tensor_new(zero.as_ptr(), 1, zero_shape.as_ptr());
    let t_exp = tl_tensor_exp(t_zero);
    assert_approx_eq(get_item_f32(t_exp, 0), 1.0);

    safe_free(t);
    safe_free(t_sqrt);
    safe_free(t_zero);
    safe_free(t_exp);
}
/// Exercises element-wise `sub`, `mul`, `div`, `pow` and `log` on small
/// rank-1 tensors and spot-checks the results.
#[test]
#[serial]
fn test_basic_ops() {
    let lhs_values = vec![10.0, 20.0, 30.0];
    let rhs_values = vec![2.0, 5.0, 3.0];
    let dims = vec![3];
    let t_a = tl_tensor_new(lhs_values.as_ptr(), 1, dims.as_ptr());
    let t_b = tl_tensor_new(rhs_values.as_ptr(), 1, dims.as_ptr());

    // Subtraction: [10, 20, 30] - [2, 5, 3]
    let t_sub = tl_tensor_sub(t_a, t_b);
    let expected_sub: [f32; 3] = [8.0, 15.0, 27.0];
    for (idx, want) in expected_sub.iter().enumerate() {
        assert_approx_eq(get_item_f32(t_sub, idx), *want);
    }

    // Multiplication: only the first element is checked.
    let t_mul = tl_tensor_mul(t_a, t_b);
    assert_approx_eq(get_item_f32(t_mul, 0), 20.0);

    // Division: the first two elements are checked.
    let t_div = tl_tensor_div(t_a, t_b);
    let expected_div: [f32; 2] = [5.0, 4.0];
    for (idx, want) in expected_div.iter().enumerate() {
        assert_approx_eq(get_item_f32(t_div, idx), *want);
    }

    // Power: a three-element base raised to a one-element exponent tensor.
    let exponent = vec![2.0];
    let unit_shape: Vec<usize> = vec![1];
    let t_factor = tl_tensor_new(exponent.as_ptr(), 1, unit_shape.as_ptr());
    let t_pow = tl_tensor_pow(t_b, t_factor);
    let expected_pow: [f32; 2] = [4.0, 25.0];
    for (idx, want) in expected_pow.iter().enumerate() {
        assert_approx_eq(get_item_f32(t_pow, idx), *want);
    }

    // Natural logarithm: ln(10) ~= 2.30258
    let log_values: Vec<f32> = vec![10.0];
    let t_log_input = tl_tensor_new(log_values.as_ptr(), 1, unit_shape.as_ptr());
    let t_log = tl_tensor_log(t_log_input);
    assert_approx_eq(get_item_f32(t_log, 0), 2.30258);

    safe_free(t_a);
    safe_free(t_b);
    safe_free(t_sub);
    safe_free(t_mul);
    safe_free(t_div);
    safe_free(t_factor);
    safe_free(t_pow);
    safe_free(t_log_input);
    safe_free(t_log);
}
/// Transposes a `[2, 2]` tensor and reshapes it to `[4, 1]`, checking one
/// transposed element and the reshaped dimensions.
#[test]
#[serial]
fn test_reshape_transpose() {
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let dims = vec![2, 2];
    let t = tl_tensor_new(values.as_ptr(), 2, dims.as_ptr());

    // After swapping axes 0 and 1, position [0, 1] should hold 3.
    let t_transposed = tl_tensor_transpose(t, 0, 1);
    let swapped_index: [i64; 2] = [0, 1];
    let val = tl_tensor_get_f32_md(t_transposed, swapped_index.as_ptr(), 2);
    assert_eq!(val, 3.0);

    // The target shape is itself passed to the runtime as a tensor of floats.
    let target_dims = vec![4.0, 1.0];
    let target_dims_shape = vec![2];
    let shape_t = tl_tensor_new(target_dims.as_ptr(), 1, target_dims_shape.as_ptr());
    let t_flat = tl_tensor_reshape_new(t, shape_t);
    assert_tensor_valid(t_flat);

    // SAFETY: `t_flat` was just checked to be non-null and is assumed to point
    // at a live `OpaqueTensor` returned by the runtime; it is freed only below.
    let reshaped_dims = unsafe { (*t_flat).0.dims().to_vec() };
    assert_eq!(reshaped_dims, vec![4, 1]);

    safe_free(t);
    safe_free(t_transposed);
    safe_free(shape_t);
    safe_free(t_flat);
}