use crate::{
Map,
dtype::Constant,
kernel::{IDX_T, Kernel, Op, OpId, Scope},
};
impl Constant {
    /// Converts a matrix of `f32` values into a matrix of `Constant::F32`.
    ///
    /// Every value is stored as its little-endian byte representation. The
    /// shape of the input (number of rows and the length of each row) is
    /// kept as-is, including empty rows.
    pub fn from_f32_matrix(matrix: Vec<Vec<f32>>) -> Vec<Vec<Constant>> {
        let mut converted = Vec::with_capacity(matrix.len());
        for row in matrix {
            let mut out_row = Vec::with_capacity(row.len());
            for value in row {
                out_row.push(Constant::F32(value.to_le_bytes()));
            }
            converted.push(out_row);
        }
        converted
    }

    /// Converts a matrix of `i32` values into a matrix of `Constant::I32`.
    ///
    /// The shape of the input (number of rows and the length of each row) is
    /// kept as-is, including empty rows.
    pub fn from_i32_matrix(matrix: Vec<Vec<i32>>) -> Vec<Vec<Constant>> {
        let mut converted = Vec::with_capacity(matrix.len());
        for row in matrix {
            let mut out_row = Vec::with_capacity(row.len());
            for value in row {
                out_row.push(Constant::I32(value));
            }
            converted.push(out_row);
        }
        converted
    }
}
impl Kernel {
    /// Prints a per-thread trace of this kernel's ops to stdout.
    ///
    /// The launch shape is derived from the kernel's `Op::Index` ops: each
    /// global index sets one entry of the global work size and each local
    /// index sets one entry of the local work size; axes without an index op
    /// keep a size of 1. One column is printed per simulated thread, followed
    /// by one row per op (walked from `self.head` via `self.next_op`) showing
    /// the value that op holds in every thread.
    ///
    /// `params` is not read by the current implementation.
    ///
    /// # Panics
    ///
    /// Only `Op::Const`, `Op::Define` (global and register scope), `Op::Store`
    /// and global-scope `Op::Index` are handled; every other op hits `todo!()`.
    /// An `Op::Index` with `Scope::Register` hits `unreachable!()` while the
    /// work sizes are being collected.
    pub fn emulate(&self, params: Vec<Vec<Constant>>) {
        self.debug();

        // Width of one per-thread column in the printed table.
        let num_space = 5;

        // Work sizes default to 1 per axis and are overwritten by index ops.
        let mut gws = vec![1; 3];
        let mut lws = vec![1; 3];
        for node in self.ops.values() {
            if let Op::Index { len, scope, axis } = node.op {
                let axis = axis as usize;
                match scope {
                    Scope::Global => gws[axis] = len,
                    Scope::Local => lws[axis] = len,
                    Scope::Register => unreachable!(),
                }
            }
        }

        let n_threads = gws.iter().product::<usize>() * lws.iter().product::<usize>();

        // Header row: one `T<i>` label per simulated thread.
        print!("Simulating gws={gws:?} lws={lws:?} ");
        for i in 0..n_threads {
            print!("{:>num_space$}", format!("T{i}"));
        }
        println!();

        // index_combinations[k][t] is the value of index k for thread t, where
        // rows 0..3 are the global indices and rows 3..6 the local indices.
        // Threads are enumerated with the last local axis varying fastest.
        let mut index_combinations = vec![vec![0usize; n_threads]; 6];
        let mut i = 0;
        for gidx0 in 0..gws[0] {
            for gidx1 in 0..gws[1] {
                for gidx2 in 0..gws[2] {
                    for lidx0 in 0..lws[0] {
                        for lidx1 in 0..lws[1] {
                            for lidx2 in 0..lws[2] {
                                index_combinations[0][i] = gidx0;
                                index_combinations[1][i] = gidx1;
                                index_combinations[2][i] = gidx2;
                                index_combinations[3][i] = lidx0;
                                index_combinations[4][i] = lidx1;
                                index_combinations[5][i] = lidx2;
                                i += 1;
                            }
                        }
                    }
                }
            }
        }

        // Per-thread storage for register-scope arrays, keyed by defining op.
        let mut reg_arrays: Map<OpId, Vec<Vec<Constant>>> = Map::default();
        // Global-scope definitions mapped to their order of appearance.
        let mut globals_map = Map::default();
        // Per-thread value of every op visited so far.
        let mut regs: Map<OpId, Vec<Constant>> = Map::default();
        // Data type of every op visited so far.
        let mut dtypes = Map::default();

        let mut op_id = self.head;
        while !op_id.is_null() {
            // For each op: the text printed for it, its dtype, and its value in
            // every thread (empty when the op produces no per-thread value).
            let (text, dtype, reg_values) = match self.ops[op_id].op {
                Op::Cast { x, dtype } => todo!(),
                Op::Unary { x, uop } => todo!(),
                Op::Binary { x, y, bop } => todo!(),
                Op::Const(constant) => {
                    let dtype = constant.dtype();
                    (format!("r{op_id}: {dtype} = {constant}"), dtype, vec![constant; n_threads])
                }
                Op::Define { dtype, scope, ro, len } => {
                    match scope {
                        Scope::Global => {
                            globals_map.insert(op_id, globals_map.len());
                        }
                        Scope::Local => todo!(),
                        Scope::Register => {
                            // Every thread gets its own zero-initialised array.
                            reg_arrays
                                .insert(op_id, vec![vec![dtype.zero_constant(); len]; n_threads]);
                        }
                    }
                    (format!("r{op_id}: {dtype} = def {scope}"), dtype, Vec::new())
                }
                Op::Store { dst, x, index, vlen } => {
                    // The store takes the dtype of the value being stored.
                    let dtype = dtypes[&x];
                    (format!("r{op_id}: {dtype} = r{x}[{index}]"), dtype, Vec::new())
                }
                Op::Load { src, index, vlen } => todo!(),
                Op::Index { len, scope, axis } => (
                    format!("r{op_id}: {IDX_T} = gidx{axis}"),
                    IDX_T,
                    match scope {
                        Scope::Global => index_combinations[axis as usize]
                            .iter()
                            .map(|&i| Constant::idx(i as u64))
                            .collect(),
                        Scope::Local => todo!(),
                        Scope::Register => todo!(),
                    },
                ),
                Op::Loop { len, axis } => {
                    let dtype = IDX_T;
                    todo!()
                }
                Op::EndLoop => todo!(),
                Op::Mad { x, y, z } => todo!(),
                Op::WMMA { dims, layout, dtype, a, b, c } => todo!(),
                Op::Vectorize { ref ops } => todo!(),
                Op::Devectorize { vec, idx } => todo!(),
                _ => todo!(),
            };

            // One table row: the op text followed by its value in each thread.
            print!("{text:>40}");
            for val in reg_values.iter() {
                print!("{:>num_space$}", val.to_string());
            }
            println!();

            dtypes.insert(op_id, dtype);
            regs.insert(op_id, reg_values);
            op_id = self.next_op(op_id);
        }
    }
}