import "primitives/core.futil";
import "primitives/memories/seq.futil";
import "primitives/binary_operators.futil";
import "primitives/math.futil";
// batch_flatten_1x4: copies every element of the 3-D memory `x` into the 2-D
// memory `x1`. The source is walked with three nested counters (i, j, k) while
// a separate running counter selects the destination column, so elements land
// in `x1[i][*]` in the same order in which they are visited in `x[i][j][k]`.
//
// Both memories are `ref` cells: the invoking component must bind them.
component batch_flatten_1x4() -> () {
  cells {
    // Memories passed in by reference from the caller.
    ref x = seq_mem_d3(32, 1, 2, 2, 1, 2, 2);
    ref x1 = seq_mem_d2(32, 1, 4, 1, 3);

    // Source indices into `x` (outermost to innermost).
    idx_i = std_reg(1);
    idx_j = std_reg(2);
    idx_k = std_reg(2);
    // Destination column index into `x1`.
    out_col = std_reg(3);
    // Holds the element read from `x` until it is written to `x1`.
    elem = std_reg(32);

    // Incrementers, one per counter.
    inc_i = std_add(1);
    inc_j = std_add(2);
    inc_k = std_add(2);
    inc_out = std_add(3);

    // Loop-guard comparators, one per source index.
    i_ok = std_le(1);
    j_ok = std_le(2);
    k_ok = std_le(2);
  }

  wires {
    // ---- Loop guards -------------------------------------------------------
    // Outer loop runs while i <= 0.
    comb group check_i {
      i_ok.right = 1'd0;
      i_ok.left = idx_i.out;
    }
    // Middle loop runs while j <= 1.
    comb group check_j {
      j_ok.right = 2'd1;
      j_ok.left = idx_j.out;
    }
    // Inner loop runs while k <= 1.
    comb group check_k {
      k_ok.right = 2'd1;
      k_ok.left = idx_k.out;
    }

    // ---- Counter initialisation --------------------------------------------
    group reset_out<"promotable"=1> {
      out_col.write_en = 1'd1;
      out_col.in = 3'd0;
      reset_out[done] = out_col.done;
    }
    group reset_i<"promotable"=1> {
      idx_i.write_en = 1'd1;
      idx_i.in = 1'd0;
      reset_i[done] = idx_i.done;
    }
    group reset_j<"promotable"=1> {
      idx_j.write_en = 1'd1;
      idx_j.in = 2'd0;
      reset_j[done] = idx_j.done;
    }
    group reset_k<"promotable"=1> {
      idx_k.write_en = 1'd1;
      idx_k.in = 2'd0;
      reset_k[done] = idx_k.done;
    }

    // ---- Data movement -----------------------------------------------------
    // Read x[i][j][k]; the value is latched into `elem` once `x` signals done.
    group load_elem<"promotable"=2> {
      x.addr0 = idx_i.out;
      x.addr1 = idx_j.out;
      x.addr2 = idx_k.out;
      x.content_en = 1'd1;
      elem.in = x.read_data;
      elem.write_en = x.done;
      load_elem[done] = elem.done;
    }
    // Write the latched element to x1[i][out_col].
    group store_elem<"promotable"=1> {
      x1.addr0 = idx_i.out;
      x1.addr1 = out_col.out;
      x1.write_data = elem.out;
      x1.write_en = 1'd1;
      x1.content_en = 1'd1;
      store_elem[done] = x1.done;
    }

    // ---- Counter increments ------------------------------------------------
    group bump_out<"promotable"=1> {
      inc_out.left = out_col.out;
      inc_out.right = 3'd1;
      out_col.in = inc_out.out;
      out_col.write_en = 1'd1;
      bump_out[done] = out_col.done;
    }
    group bump_k<"promotable"=1> {
      inc_k.left = idx_k.out;
      inc_k.right = 2'd1;
      idx_k.in = inc_k.out;
      idx_k.write_en = 1'd1;
      bump_k[done] = idx_k.done;
    }
    group bump_j<"promotable"=1> {
      inc_j.left = idx_j.out;
      inc_j.right = 2'd1;
      idx_j.in = inc_j.out;
      idx_j.write_en = 1'd1;
      bump_j[done] = idx_j.done;
    }
    group bump_i<"promotable"=1> {
      inc_i.left = idx_i.out;
      inc_i.right = 1'd1;
      idx_i.in = inc_i.out;
      idx_i.write_en = 1'd1;
      bump_i[done] = idx_i.done;
    }
  }

  control {
    seq {
      // The destination column is reset once and never again, so it keeps
      // counting across all iterations of the nested loops below.
      @pos(0) reset_out;
      @pos(1) reset_i;
      @bound(1) while i_ok.out with check_i {
        seq {
          @pos(2) reset_j;
          @bound(2) while j_ok.out with check_j {
            seq {
              @pos(3) reset_k;
              @bound(2) while k_ok.out with check_k {
                seq {
                  // Move one element, then advance destination and inner index.
                  @pos(4) load_elem;
                  @pos(5) store_elem;
                  @pos(6) bump_out;
                  @pos(3) bump_k;
                }
              }
              @pos(2) bump_j;
            }
          }
          @pos(1) bump_i;
        }
      }
    }
  }
}
// Top-level entry point: declares the two externally visible memories and
// runs the flatten component once with those memories bound to its `ref` cells.
component main() -> () {
  cells {
    // Source and destination memories, exposed outside the design.
    @external x = seq_mem_d3(32, 1, 2, 2, 1, 2, 2);
    @external x1 = seq_mem_d2(32, 1, 4, 1, 3);
    // The single instance of the flatten component.
    flatten_inst = batch_flatten_1x4();
  }

  // No continuous assignments are needed; all wiring happens through `invoke`.
  wires {}

  control {
    seq {
      @pos(0) invoke flatten_inst[x = x, x1 = x1]()();
    }
  }
}
metadata #{
0: let %x1: Tensor[(1, 4), int32] /* ty=Tensor[(1, 4), int32] span=from_string:4:3 */ = nn.batch_flatten(%x) /* ty=Tensor[(1, 4), int32] span=from_string:3:38 */;
}#