import torch
def f32_arr(tensor):
    """Format a tensor as a Rust ``&[f32]`` slice literal.

    The tensor is detached, made contiguous, cast to float32 and flattened,
    so strided, transposed or expanded views are written in their logical
    (row-major) element order. Each element is printed with 7 digits after
    the decimal point.

    Args:
        tensor: any torch tensor.

    Returns:
        A string such as ``"&[0.1000000, 1.5000000]"``.

    Raises:
        ValueError: if any element is NaN or infinite. Python would render
            those as ``nan``/``inf``, which are not Rust float literals and
            would only fail later when the generated Rust is compiled.
    """
    flat = tensor.detach().contiguous().float().flatten()
    if not bool(torch.isfinite(flat).all()):
        raise ValueError(
            "f32_arr: tensor contains NaN or infinite values, "
            "which cannot be emitted as Rust f32 literals"
        )
    return "&[" + ", ".join(f"{v:.7f}" for v in flat.tolist()) + "]"
def i64_arr(tensor):
    """Render a tensor as a Rust ``&[i64]`` slice literal.

    Elements are cast with ``.long()`` (so floats are truncated), flattened in
    logical order, and each is written with an ``_i64`` suffix, e.g.
    ``"&[1_i64, 0_i64]"``.
    """
    flat = tensor.detach().contiguous().long().flatten()
    items = [f"{int_val}_i64" for int_val in flat.tolist()]
    return "&[{}]".format(", ".join(items))
def rand_uniform(shape, seed):
    """Return a reproducible tensor of uniform samples scaled into roughly [0.1, 2.0).

    Reseeds the *global* torch RNG with ``seed`` first, so the same
    ``(shape, seed)`` pair always yields the same tensor. The 0.1 offset keeps
    every value strictly positive, which the callers rely on for ``log``,
    ``sqrt`` and division.
    """
    torch.manual_seed(seed)
    unit = torch.rand(shape)  # samples in [0, 1)
    return unit * 1.9 + 0.1
def print_input(name, t):
    """Print ``t`` as a Rust ``const {name}: &[f32]`` item.

    The item is preceded by ``#[rustfmt::skip]`` so rustfmt leaves the long
    literal on one line.
    """
    print("#[rustfmt::skip]")
    body = f32_arr(t)
    print(f"const {name}: &[f32] = {body};")
def print_suite(prefix, ta, tb):
    """Print the shared Rust comparison lines for one tensor-view scenario.

    ``ta`` and ``tb`` are the PyTorch reference tensors. Each emitted line
    compares a Burn expression on ``te`` / ``te2`` (names the enclosing Rust
    test must define) against the PyTorch result. Lines are labelled
    ``"{prefix}/<op>"`` and grouped under ``// unary``, ``// binary``,
    ``// reduce`` and ``// softmax`` comment lines.
    """

    def emit(label, rust_expr, expected, checker="cmp", render=f32_arr):
        # Every line has the shape:
        #   ` <checker>("<prefix>/<label>", &<rust_expr>.into_data(), <literal>);`
        print(f' {checker}("{prefix}/{label}", &{rust_expr}.into_data(), {render(expected)});')

    print(" // unary")
    unary_ops = (
        ("sin", torch.sin),
        ("cos", torch.cos),
        ("exp", torch.exp),
        ("log", torch.log),
        ("sqrt", torch.sqrt),
        ("abs", torch.abs),
        ("neg", torch.neg),
    )
    unary_results = [(method, op(ta)) for method, op in unary_ops]
    for method, expected in unary_results:
        # The label doubles as the Burn tensor method name.
        emit(method, f"te.clone().{method}()", expected)

    print(" // binary")
    binary_results = [
        ("add", "+", ta + tb),
        ("mul", "*", ta * tb),
        ("sub", "-", ta - tb),
        ("div", "/", ta / tb),
    ]
    for label, symbol, expected in binary_results:
        emit(label, f"(te.clone() {symbol} te2.clone())", expected)

    print(" // reduce")
    sums = [(axis, torch.sum(ta, dim=axis, keepdim=True)) for axis in (2, 1)]
    for axis, expected in sums:
        emit(f"sum_dim{axis}", f"te.clone().sum_dim({axis})", expected)
    emit("mean_dim2", "te.clone().mean_dim(2)", torch.mean(ta, dim=2, keepdim=True))
    emit(
        "argmax_dim2",
        "te.clone().argmax(2)",
        torch.argmax(ta, dim=2, keepdim=True),
        checker="cmp_i64",
        render=i64_arr,
    )

    print(" // softmax")
    emit(
        "softmax_dim2",
        "burn_tensor::activation::softmax(te.clone(), 2)",
        torch.softmax(ta, dim=2),
    )
def _emit_input_pair(const_prefix, shape, seed):
    """Draw, print and return the A/B input tensors for one scenario.

    A uses ``seed`` and B uses ``seed + 1``. They are printed, in that order,
    as the Rust constants ``{const_prefix}_A`` and ``{const_prefix}_B``.
    """
    a = rand_uniform(shape, seed=seed)
    b = rand_uniform(shape, seed=seed + 1)
    print_input(f"{const_prefix}_A", a)
    print_input(f"{const_prefix}_B", b)
    return a, b


def _print_expand_suite(ta, tb):
    """Print the reduced assertion set for the ``expand_s0`` (broadcast-view) scenario."""
    print(" // unary subset")
    for method, op in (
        ("sin", torch.sin),
        ("exp", torch.exp),
        ("neg", torch.neg),
        ("abs", torch.abs),
    ):
        print(f' cmp("expand_s0/{method}", &te.clone().{method}().into_data(), {f32_arr(op(ta))});')
    print(" // binary")
    for label, symbol, op in (("add", "+", torch.add), ("mul", "*", torch.mul)):
        print(f' cmp("expand_s0/{label}", &(te.clone() {symbol} te2.clone()).into_data(), {f32_arr(op(ta, tb))});')
    print(" // reduce")
    for kind, dim, reducer in (
        ("sum", 0, torch.sum),
        ("sum", 2, torch.sum),
        ("mean", 1, torch.mean),
    ):
        expected = reducer(ta, dim=dim, keepdim=True)
        print(f' cmp("expand_s0/{kind}_dim{dim}", &te.clone().{kind}_dim({dim}).into_data(), {f32_arr(expected)});')
    print(" // softmax")
    print(f' cmp("expand_s0/softmax_dim2", &burn_tensor::activation::softmax(te.clone(), 2).into_data(), {f32_arr(torch.softmax(ta, dim=2))});')


def main():
    """Print Rust fixtures comparing ops on non-contiguous tensor views against PyTorch.

    Each scenario prints its input constants, then the expected results of
    running ops on a strided, narrowed, transposed or expanded view of those
    inputs. The Rust side is expected to build the same views as
    ``te`` / ``te2``.
    """
    print("\n// === Scenario 1: step2_last [2,3,4] -> [2,3,2] ===")
    a, b = _emit_input_pair("STEP2_LAST", [2, 3, 4], seed=100)
    ta, tb = a[:, :, ::2], b[:, :, ::2]
    print_suite("step2_last", ta, tb)
    mm = torch.matmul(ta, ta.transpose(-1, -2))
    print(f' cmp("step2_last/matmul", &te.clone().matmul(te.transpose()).into_data(), {f32_arr(mm)});')

    print("\n// === Scenario 2: step3_mid [2,6,3] -> [2,2,3] ===")
    a, b = _emit_input_pair("STEP3_MID", [2, 6, 3], seed=200)
    ta, tb = a[:, ::3, :], b[:, ::3, :]
    print_suite("step3_mid", ta, tb)

    print("\n// === Scenario 3: narrow_transpose [2,4,3] -> [2,3,2] ===")
    a, b = _emit_input_pair("NARROW_TRANS", [2, 4, 3], seed=300)
    ta = a.narrow(1, 1, 2).transpose(1, 2)
    tb = b.narrow(1, 1, 2).transpose(1, 2)
    print_suite("narrow_trans", ta, tb)

    print("\n// === Scenario 4: step2_transpose [2,3,4] -> [2,2,3] ===")
    a, b = _emit_input_pair("STEP2_TRANS", [2, 3, 4], seed=400)
    ta = a[:, :, ::2].transpose(1, 2)
    tb = b[:, :, ::2].transpose(1, 2)
    print_suite("step2_trans", ta, tb)

    print("\n// === Scenario 5: expand_s0 [1,3,2] -> [2,3,2] ===")
    a, b = _emit_input_pair("EXPAND_S0", [1, 3, 2], seed=500)
    _print_expand_suite(a.expand(2, 3, 2), b.expand(2, 3, 2))

    print("\n// === Scenario 6: gather on step-sliced [2,2,4] -> [2,2,2] ===")
    a = rand_uniform([2, 2, 4], seed=600)
    print_input("GATHER_INPUT", a)
    ta = a[:, :, ::2]
    # NOTE: these indices are not printed; the Rust test's `idx` must be
    # hand-written with the same values or the comparison is meaningless.
    indices = torch.tensor([[[1], [0]], [[0], [1]]])
    result = torch.gather(ta, 2, indices)
    print(f' cmp("gather/step2", &te.gather(2, idx).into_data(), {f32_arr(result)});')
if __name__ == "__main__":
    # Emit the fixtures only when run as a script, not when imported.
    main()