use tokitai_operator::backend::cpu::CpuScalarBackend;
use tokitai_operator::backend::{Backend, TensorStore};
use tokitai_operator::domain::DomainId;
use tokitai_operator::ir::SemanticGraph;
use tokitai_operator::object::{ObjectMeta, Representation, Shape, Tensor};
use tokitai_operator::op::{
BroadcastOp, ConcatOp, FlattenOp, PermuteOp, ReshapeOp, SliceOp, SqueezeOp, TransposeOp,
UnsqueezeOp,
};
use tokitai_operator::planner::HeuristicPlanner;
/// Builds tensor metadata in the `"integer"` domain with the given static
/// `shape` and a dense CPU representation.
fn int_meta(shape: Vec<usize>) -> ObjectMeta {
    let domain = DomainId::new("integer");
    let shape = Shape::from(shape);
    let repr = Representation::dense_cpu();
    ObjectMeta::tensor(domain, shape, repr)
}
/// Returns the `CpuScalarBackend` used by every test in this file, both to
/// query capabilities for the planner and to execute the planned graph.
fn cpu() -> CpuScalarBackend {
CpuScalarBackend
}
/// Reshaping a row-major `[2, 3]` tensor into `[3, 2]` keeps the flat element
/// order and takes its dimensions from the target tensor.
#[test]
fn reshape_2x3_to_3x2_preserves_data() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 3]));
    let tgt = g.add_input(int_meta(vec![3, 2]));
    let out = g.add_op(ReshapeOp, &[t, tgt]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]));
    // The target is all zeros here; only its shape is asserted on the output.
    store.insert(tgt, int_tensor(vec![3, 2], vec![0; 6]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![1, 2, 3, 4, 5, 6]);
    assert_eq!(result.meta.shape.dims.len(), 2);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![3, 2]);
}
/// Reshape between shapes with different element counts (`[2, 3]` has 6
/// elements, `[4, 2]` has 8) must fail during execution. The error text must
/// mention `reshape`.
#[test]
fn reshape_rejects_dim_product_mismatch() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 3]));
    let tgt = g.add_input(int_meta(vec![4, 2]));
    // Graph construction and planning are expected to succeed; only execution fails.
    g.add_op(ReshapeOp, &[t, tgt]).unwrap();

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]));
    store.insert(tgt, int_tensor(vec![4, 2], vec![0; 8]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    let msg = cpu()
        .execute_i64(&g, &plan, &mut store)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("reshape"), "got: {msg}");
}
/// Transposing a `[2, 3]` matrix with axes `[1, 0]` yields its `[3, 2]`
/// transpose, laid out in row-major order.
#[test]
fn transpose_swap_axes_swaps_data() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 3]));
    let axes = g.add_input(int_meta(vec![2]));
    let out = g.add_op(TransposeOp, &[t, axes]).unwrap()[0];

    let mut store = TensorStore::new();
    // Row-major [[1, 2, 3], [4, 5, 6]].
    store.insert(t, int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]));
    store.insert(axes, int_tensor(vec![2], vec![1, 0]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    // Row-major [[1, 4], [2, 5], [3, 6]].
    assert_eq!(result.data, vec![1, 4, 2, 5, 3, 6]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![3, 2]);
}
/// Permuting a `[2, 3, 4]` tensor with `[2, 0, 1]` produces shape `[4, 2, 3]`.
/// Output element `(o0, o1, o2)` equals input element `(o1, o2, o0)`.
#[test]
fn permute_3d_full_permutation() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 3, 4]));
    let perm = g.add_input(int_meta(vec![3]));
    let out = g.add_op(PermuteOp, &[t, perm]).unwrap()[0];

    // Each element equals its own flat index, so the expected output is easy to derive.
    let source: Vec<i64> = (0..24).collect();
    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![2, 3, 4], source.clone()));
    store.insert(perm, int_tensor(vec![3], vec![2, 0, 1]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![4, 2, 3]);

    // Row-major strides of the [2, 3, 4] input are [12, 4, 1].
    let input_at = |a0: usize, a1: usize, a2: usize| source[a0 * 12 + a1 * 4 + a2];
    let mut expected = Vec::with_capacity(24);
    for o0 in 0..4 {
        for o1 in 0..2 {
            for o2 in 0..3 {
                expected.push(input_at(o1, o2, o0));
            }
        }
    }
    assert_eq!(result.data, expected);
}
/// Slicing a length-6 vector with bounds `[0, 1, 4]` keeps elements `1..4`
/// in their original order.
///
/// NOTE(review): the bounds layout is presumably `[axis, start, end)`, based
/// on the expected data; confirm against `SliceOp`.
#[test]
fn slice_half_open_keeps_order() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![6]));
    let bounds = g.add_input(int_meta(vec![3]));
    let out = g.add_op(SliceOp, &[t, bounds]).unwrap();

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![6], vec![10, 20, 30, 40, 50, 60]));
    store.insert(bounds, int_tensor(vec![3], vec![0, 1, 4]));

    let plan = HeuristicPlanner::new(cpu().capabilities())
        .plan(&g)
        .unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let out_t = store.get(out[0]).unwrap();
    assert_eq!(out_t.data, vec![20, 30, 40]);
    // The output must be 1-D with three elements, not just carry the right data.
    let dims: Vec<usize> = out_t
        .meta
        .shape
        .dims
        .iter()
        .map(|d| d.value().unwrap())
        .collect();
    assert_eq!(dims, vec![3]);
}
/// A slice whose end bound (8) exceeds the input length (4) must fail during
/// execution. The error text must mention `slice`.
#[test]
fn slice_rejects_out_of_bounds() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![4]));
    let bounds = g.add_input(int_meta(vec![3]));
    // Graph construction and planning are expected to succeed; only execution fails.
    g.add_op(SliceOp, &[t, bounds]).unwrap();

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![4], vec![1, 2, 3, 4]));
    store.insert(bounds, int_tensor(vec![3], vec![0, 0, 8]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    let msg = cpu()
        .execute_i64(&g, &plan, &mut store)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("slice"), "got: {msg}");
}
/// Concatenating `[2, 3]` and `[1, 3]` along axis 0 appends `b`'s row after
/// `a`'s rows, giving `[3, 3]`.
#[test]
fn concat_along_axis_0() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let a = g.add_input(int_meta(vec![2, 3]));
    let b = g.add_input(int_meta(vec![1, 3]));
    let axis = g.add_input(int_meta(vec![1]));
    let out = g.add_op(ConcatOp, &[a, b, axis]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(a, int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]));
    store.insert(b, int_tensor(vec![1, 3], vec![7, 8, 9]));
    store.insert(axis, int_tensor(vec![1], vec![0]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![3, 3]);
}
/// Concatenating `[2, 2]` and `[2, 1]` along axis 1 appends `b`'s single
/// column to each row of `a`, giving `[2, 3]`.
#[test]
fn concat_along_axis_1() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let a = g.add_input(int_meta(vec![2, 2]));
    let b = g.add_input(int_meta(vec![2, 1]));
    let axis = g.add_input(int_meta(vec![1]));
    let out = g.add_op(ConcatOp, &[a, b, axis]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(a, int_tensor(vec![2, 2], vec![1, 2, 3, 4]));
    store.insert(b, int_tensor(vec![2, 1], vec![5, 6]));
    store.insert(axis, int_tensor(vec![1], vec![1]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    // Rows become [1, 2, 5] and [3, 4, 6].
    assert_eq!(result.data, vec![1, 2, 5, 3, 4, 6]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![2, 3]);
}
/// Broadcasting a single-element `[1]` tensor to the target shape `[4]`
/// repeats the value four times and adopts the target's shape.
#[test]
fn broadcast_scalar_to_n() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let s = g.add_input(int_meta(vec![1]));
    let tgt = g.add_input(int_meta(vec![4]));
    let out = g.add_op(BroadcastOp, &[s, tgt]).unwrap();

    let mut store = TensorStore::new();
    store.insert(s, int_tensor(vec![1], vec![7]));
    store.insert(tgt, int_tensor(vec![4], vec![0; 4]));

    let plan = HeuristicPlanner::new(cpu().capabilities())
        .plan(&g)
        .unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let out_t = store.get(out[0]).unwrap();
    assert_eq!(out_t.data, vec![7, 7, 7, 7]);
    // Also pin the shape, as broadcast_2_1_to_2_3 does; correct data alone
    // would not catch an output that kept the source shape `[1]`.
    let dims: Vec<usize> = out_t
        .meta
        .shape
        .dims
        .iter()
        .map(|d| d.value().unwrap())
        .collect();
    assert_eq!(dims, vec![4]);
}
/// Broadcasting a `[2, 1]` column to `[2, 3]` repeats each row's single value
/// across all three columns.
#[test]
fn broadcast_2_1_to_2_3() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 1]));
    let tgt = g.add_input(int_meta(vec![2, 3]));
    let out = g.add_op(BroadcastOp, &[t, tgt]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![2, 1], vec![10, 20]));
    store.insert(tgt, int_tensor(vec![2, 3], vec![0; 6]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![10, 10, 10, 20, 20, 20]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![2, 3]);
}
/// Flattening a `[2, 2]` matrix yields a rank-1 tensor of length 4 with the
/// row-major element order unchanged.
#[test]
fn flatten_matrix_to_vector_preserves_order() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![2, 2]));
    let out = g.add_op(FlattenOp, &[t]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![2, 2], vec![1, 2, 3, 4]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    let out_dims = &result.meta.shape.dims;
    assert_eq!(result.data, vec![1, 2, 3, 4]);
    assert_eq!(out_dims.len(), 1);
    assert_eq!(out_dims[0].value().unwrap(), 4);
}
/// Squeeze with no axis argument removes every size-1 dimension:
/// `[1, 3, 1, 2]` becomes `[3, 2]` and the data is unchanged.
#[test]
fn squeeze_drops_all_size_1_dims() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![1, 3, 1, 2]));
    let out = g.add_op(SqueezeOp, &[t]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![1, 3, 1, 2], vec![1, 2, 3, 4, 5, 6]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![1, 2, 3, 4, 5, 6]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![3, 2]);
}
/// Unsqueeze at axis 0 turns a `[3]` vector into `[1, 3]` and leaves the data
/// unchanged.
#[test]
fn unsqueeze_dim_0_inserts_leading_one() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![3]));
    let axis = g.add_input(int_meta(vec![1]));
    let out = g.add_op(UnsqueezeOp, &[t, axis]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![3], vec![1, 2, 3]));
    store.insert(axis, int_tensor(vec![1], vec![0]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![1, 2, 3]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![1, 3]);
}
/// Unsqueeze with the negative axis `-1` appends a trailing unit dimension,
/// turning `[3]` into `[3, 1]`.
#[test]
fn unsqueeze_dim_neg1_inserts_trailing_one() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![3]));
    let axis = g.add_input(int_meta(vec![1]));
    let out = g.add_op(UnsqueezeOp, &[t, axis]).unwrap()[0];

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![3], vec![1, 2, 3]));
    store.insert(axis, int_tensor(vec![1], vec![-1]));

    let plan = HeuristicPlanner::new(cpu().capabilities()).plan(&g).unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let result = store.get(out).unwrap();
    assert_eq!(result.data, vec![1, 2, 3]);
    let dims = result
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| dim.value().unwrap())
        .collect::<Vec<usize>>();
    assert_eq!(dims, vec![3, 1]);
}
/// A reshape → add → reshape pipeline computes the element-wise sum. Per the
/// test name, neither graph input may be modified by execution, so both `t`
/// and `rhs` are checked afterwards.
#[test]
fn reshape_add_reshape_does_not_alias() {
    use tokitai_operator::op::AddOp;
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let t = g.add_input(int_meta(vec![6]));
    let tgt1 = g.add_input(int_meta(vec![2, 3]));
    let reshaped = g.add_op(ReshapeOp, &[t, tgt1]).unwrap();
    let reshaped_t = reshaped[0];
    let rhs = g.add_input(int_meta(vec![2, 3]));
    let sum = g.add_op(AddOp, &[reshaped_t, rhs]).unwrap();
    let tgt2 = g.add_input(int_meta(vec![6]));
    let reshaped2 = g.add_op(ReshapeOp, &[sum[0], tgt2]).unwrap();

    let mut store = TensorStore::new();
    store.insert(t, int_tensor(vec![6], vec![1, 2, 3, 4, 5, 6]));
    store.insert(tgt1, int_tensor(vec![2, 3], vec![0; 6]));
    store.insert(rhs, int_tensor(vec![2, 3], vec![10, 20, 30, 40, 50, 60]));
    store.insert(tgt2, int_tensor(vec![6], vec![0; 6]));

    let plan = HeuristicPlanner::new(cpu().capabilities())
        .plan(&g)
        .unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let out_t = store.get(reshaped2[0]).unwrap();
    assert_eq!(out_t.data, vec![11, 22, 33, 44, 55, 66]);
    let dims: Vec<usize> = out_t
        .meta
        .shape
        .dims
        .iter()
        .map(|d| d.value().unwrap())
        .collect();
    assert_eq!(dims, vec![6]);

    // Both graph inputs feeding the add must come out untouched; checking only
    // `t` would miss a write-through into the right-hand operand.
    let src = store.get(t).unwrap();
    assert_eq!(src.data, vec![1, 2, 3, 4, 5, 6]);
    let rhs_t = store.get(rhs).unwrap();
    assert_eq!(rhs_t.data, vec![10, 20, 30, 40, 50, 60]);
}
/// Chains concat, slice and reshape:
/// - Two `[1, 3]` rows are concatenated on axis 0, giving `[2, 3]`.
/// - The second row is sliced out with bounds `[0, 1, 2]`.
/// - The slice is reshaped to the rank-1 shape `[3]`.
#[test]
fn concat_then_slice_then_reshape() {
    let int_tensor = |dims: Vec<usize>, data: Vec<i64>| {
        Tensor::dense_cpu(DomainId::new("integer"), Shape::from(dims), data)
    };

    let mut g = SemanticGraph::new();
    let a = g.add_input(int_meta(vec![1, 3]));
    let b = g.add_input(int_meta(vec![1, 3]));
    let axis = g.add_input(int_meta(vec![1]));
    let cat = g.add_op(ConcatOp, &[a, b, axis]).unwrap();
    let bounds = g.add_input(int_meta(vec![3]));
    let sliced = g.add_op(SliceOp, &[cat[0], bounds]).unwrap();
    let tgt = g.add_input(int_meta(vec![3]));
    let final_out = g.add_op(ReshapeOp, &[sliced[0], tgt]).unwrap();

    let mut store = TensorStore::new();
    store.insert(a, int_tensor(vec![1, 3], vec![1, 2, 3]));
    store.insert(b, int_tensor(vec![1, 3], vec![4, 5, 6]));
    store.insert(axis, int_tensor(vec![1], vec![0]));
    store.insert(bounds, int_tensor(vec![3], vec![0, 1, 2]));
    store.insert(tgt, int_tensor(vec![3], vec![0; 3]));

    let plan = HeuristicPlanner::new(cpu().capabilities())
        .plan(&g)
        .unwrap();
    cpu().execute_i64(&g, &plan, &mut store).unwrap();

    let out_t = store.get(final_out[0]).unwrap();
    assert_eq!(out_t.data, vec![4, 5, 6]);
    // The final reshape must produce the target's rank-1 shape, not keep the
    // sliced `[1, 3]` layout.
    let dims: Vec<usize> = out_t
        .meta
        .shape
        .dims
        .iter()
        .map(|d| d.value().unwrap())
        .collect();
    assert_eq!(dims, vec![3]);
}