use crate::{
DType, DebugMask, Map, Set, ZyxError,
backend::{AutotuneConfig, BufferId, Device, DeviceId, Event, MemoryPool, PoolId},
dtype::Constant,
graph::{Graph, Node},
kernel::{BOp, Kernel, MoveOp, Op, OpId, OpNode, Scope, UOp},
kernel_cache::KernelCache,
runtime::{Runtime, deallocate_tensors},
schedule::schedule,
slab::{Slab, SlabId},
tensor::TensorId,
view::View,
};
use std::collections::{BTreeMap, BTreeSet};
use std::hash::BuildHasherDefault;
/// Identifier of an in-flight kernel inside the kernelizer's kernel slab.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KMKernelId(u32);
// Slab key implementation: ids count up from 0 and `u32::MAX` is the
// sentinel "null" id, so valid ids must stay below `u32::MAX`.
impl SlabId for KMKernelId {
const ZERO: Self = Self(0);
const NULL: Self = Self(u32::MAX);
// Advance to the next id. Incrementing into NULL is not checked here;
// the slab is assumed to never hold u32::MAX live entries.
fn inc(&mut self) {
self.0 += 1;
}
}
impl From<usize> for KMKernelId {
    /// Converts a slab index into a kernel id.
    ///
    /// Ids are stored as `u32`; an index at or above `u32::MAX` would
    /// silently truncate (and `u32::MAX` collides with `NULL`), so guard
    /// against that in debug builds before the narrowing cast.
    fn from(value: usize) -> Self {
        debug_assert!(value < u32::MAX as usize, "kernel id {value} does not fit into u32");
        KMKernelId(value as u32)
    }
}
impl From<KMKernelId> for usize {
fn from(value: KMKernelId) -> Self {
value.0 as usize
}
}
/// Builds kernels from graph nodes in topological order and launches them
/// once all of their inputs are realized.
struct Kernelizer<'a> {
// must_keep_nodes: tensors that must never be deallocated (pre-realized or
// requested via to_eval); virt_realized_nodes: realized or scheduled to be
// realized by an already-emitted store; realized_nodes: tensors whose
// kernels have actually been launched; kernels: in-flight kernels.
must_keep_nodes: Set<TensorId>, virt_realized_nodes: Set<TensorId>, realized_nodes: Set<TensorId>, kernels: Slab<KMKernelId, Kernel>,
// Maps each processed tensor to the kernel and op that produce it.
visited: Map<TensorId, (KMKernelId, OpId)>,
// Remaining reference counts; decremented as consumers are kernelized.
rcs: Map<TensorId, u32>,
graph: &'a Graph,
pools: &'a mut Slab<PoolId, MemoryPool>,
events: &'a mut Map<BTreeSet<BufferId>, Event>,
buffer_map: &'a mut Map<TensorId, BufferId>,
temp_data: &'a mut Map<BufferId, Box<[u8]>>,
devices: &'a mut Slab<DeviceId, Device>,
cache: &'a mut KernelCache,
autotune_config: &'a AutotuneConfig,
debug: DebugMask,
// Number of kernels launched so far; reported when perf debugging is on.
n_launches: u32,
}
impl<'a> Kernelizer<'a> {
/// Construct a kernelizer over `graph`, seeding it with the tensors that
/// are already materialized and the tensors the caller wants evaluated.
fn new(
    realized_nodes: Set<TensorId>,
    to_eval: &'a Set<TensorId>,
    rcs: Map<TensorId, u32>,
    graph: &'a Graph,
    pools: &'a mut Slab<PoolId, MemoryPool>,
    events: &'a mut Map<BTreeSet<BufferId>, Event>,
    buffer_map: &'a mut Map<TensorId, BufferId>,
    temp_data: &'a mut Map<BufferId, Box<[u8]>>,
    devices: &'a mut Slab<DeviceId, Device>,
    cache: &'a mut KernelCache,
    search_config: &'a AutotuneConfig,
    debug: DebugMask,
) -> Self {
    // Never deallocate anything that is already realized or requested.
    let mut keep = realized_nodes.clone();
    keep.extend(to_eval);
    // The virtually-realized set starts out equal to the realized set.
    let virt = realized_nodes.clone();
    Self {
        must_keep_nodes: keep,
        virt_realized_nodes: virt,
        realized_nodes,
        kernels: Slab::with_capacity(30),
        visited: Map::with_capacity_and_hasher(100, BuildHasherDefault::new()),
        rcs,
        graph,
        pools,
        events,
        buffer_map,
        temp_data,
        devices,
        cache,
        autotune_config: search_config,
        debug,
        n_launches: 0,
    }
}
/// Dump every pending kernel to stdout (debugging aid).
#[allow(unused)]
fn debug(&self) {
    self.kernels.values().for_each(|kernel| kernel.debug());
    println!();
}
/// True if `nid` is already realized, or a store for it has been emitted
/// (so it will be realized once the pending kernels launch).
fn is_virt_realized(&self, nid: TensorId) -> bool {
self.virt_realized_nodes.contains(&nid)
}
/// Prepare the kernel producing `x` so a new op consuming `x` can be
/// appended, returning the kernel id and the op id of `x` inside it.
///
/// If that kernel already contains stores, `x` is first stored and then
/// reloaded into a fresh kernel. If the kernel has other pending outputs,
/// it is duplicated so only the copy producing the new op gets mutated —
/// except when `x` sits after a reduce, where store+reload is preferred
/// over duplication (presumably to avoid replaying the reduce — the
/// `reduce_dims_big` name suggests a cost heuristic; confirm with author).
fn duplicate_or_store(&mut self, x: TensorId) -> Result<(KMKernelId, OpId), ZyxError> {
let (mut kid, mut op_id) = self.visited[&x];
if self.kernels[kid].contains_stores() {
self.add_store(x)?;
(kid, op_id) = self.create_load_kernel(x);
if self.kernels[kid].outputs.len() > 1 {
kid = self.duplicate_kernel(x, kid);
}
}
if self.kernels[kid].outputs.len() > 1 {
let reduce_dims_big = self.kernels[kid].is_preceded_by_reduce(op_id);
if reduce_dims_big {
self.add_store(x)?;
(kid, op_id) = self.create_load_kernel(x);
if self.kernels[kid].outputs.len() > 1 {
kid = self.duplicate_kernel(x, kid);
}
} else {
kid = self.duplicate_kernel(x, kid);
}
}
Ok((kid, op_id))
}
/// Split off a copy of kernel `kid` that produces only `x`.
///
/// The copy keeps `x` as its single output; the original loses one `x`
/// output claim. Both kernels then drop ops that none of their remaining
/// outputs need. Returns the id of the new kernel.
fn duplicate_kernel(&mut self, x: TensorId, kid: KMKernelId) -> KMKernelId {
let mut kernel = self.kernels[kid].clone();
kernel.outputs = vec![x];
kernel.drop_unused_ops(&self.visited);
self.kernels[kid].remove_first_output(x);
self.kernels[kid].drop_unused_ops(&self.visited);
self.kernels.push(kernel)
}
/// Start a fresh kernel whose only op loads tensor `nid` from memory.
/// Registers `nid` in `visited` and returns the kernel id and load op id.
fn create_load_kernel(&mut self, nid: TensorId) -> (KMKernelId, OpId) {
    let dtype = self.graph.dtype(nid);
    let view = View::contiguous(self.graph.shape(nid));
    let mut ops = Slab::with_capacity(100);
    let load_id = ops.push(OpNode {
        prev: OpId::NULL,
        next: OpId::NULL,
        op: Op::LoadView(Box::new((dtype, view))),
    });
    // One output entry per remaining reference, so each later consumer
    // can claim one.
    let n_refs = self.rcs[&nid] as usize;
    let kid = self.kernels.push(Kernel {
        outputs: vec![nid; n_refs],
        loads: vec![nid],
        stores: Vec::new(),
        ops,
        head: load_id,
        tail: load_id,
    });
    self.visited.insert(nid, (kid, load_id));
    (kid, load_id)
}
/// Start a fresh kernel that materializes the constant `value` for tensor
/// `nid`, using a rank-1 `[1]` view.
fn create_const_kernel(&mut self, nid: TensorId, value: Constant) {
    let mut ops = Slab::with_capacity(100);
    let const_id = ops.push(OpNode {
        prev: OpId::NULL,
        next: OpId::NULL,
        op: Op::ConstView(Box::new((value, View::contiguous(&[1])))),
    });
    // One output entry per remaining reference to `nid`.
    let n_refs = self.rcs[&nid] as usize;
    let kid = self.kernels.push(Kernel {
        outputs: vec![nid; n_refs],
        loads: Vec::new(),
        stores: Vec::new(),
        ops,
        head: const_id,
        tail: const_id,
    });
    self.visited.insert(nid, (kid, const_id));
}
/// Append an `Expand` movement op producing `nid` from `x`.
///
/// Unlike the other movement ops, expand also forces a store+reload when
/// `x` is preceded by a reduce even if the kernel has a single output.
fn add_expand_op(&mut self, nid: TensorId, x: TensorId) -> Result<(), ZyxError> {
    let (mut kid, mut op_id) = self.visited[&x];
    // `||` short-circuits: the original bitwise `|` always evaluated the
    // reduce check even when the kernel already contained stores.
    if self.kernels[kid].contains_stores() || self.kernels[kid].is_preceded_by_reduce(op_id) {
        self.add_store(x)?;
        (kid, op_id) = self.create_load_kernel(x);
        if self.kernels[kid].outputs.len() > 1 {
            kid = self.duplicate_kernel(x, kid);
        }
    }
    if self.kernels[kid].outputs.len() > 1 {
        // Same policy as duplicate_or_store: after a reduce, prefer
        // store+reload over duplicating the kernel.
        let reduce_dims_big = self.kernels[kid].is_preceded_by_reduce(op_id);
        if reduce_dims_big {
            self.add_store(x)?;
            (kid, op_id) = self.create_load_kernel(x);
            if self.kernels[kid].outputs.len() > 1 {
                kid = self.duplicate_kernel(x, kid);
            }
        } else {
            kid = self.duplicate_kernel(x, kid);
        }
    }
    let shape = self.graph.shape(nid);
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Move { x: op_id, mop: Box::new(MoveOp::Expand { shape: shape.into() }) });
    kernel.remove_first_output(x);
    kernel.outputs.extend(vec![nid; self.rcs[&nid] as usize]);
    *self.rcs.get_mut(&x).unwrap() -= 1;
    debug_assert_eq!(self.graph.shape(nid), kernel.shape());
    self.visited.insert(nid, (kid, op_id));
    Ok(())
}
/// Append a `Reshape` movement op producing `nid` from `x`.
fn add_reshape_op(&mut self, nid: TensorId, x: TensorId) -> Result<(), ZyxError> {
    debug_assert!(self.visited.contains_key(&x), "Missing tensor {x} in visited.");
    let (kid, src) = self.duplicate_or_store(x)?;
    let shape = self.graph.shape(nid);
    let n_refs = self.rcs[&nid] as usize;
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Move { x: src, mop: Box::new(MoveOp::Reshape { shape: shape.into() }) });
    // Consume one output claim of `x` and expose `nid` once per reference.
    kernel.remove_first_output(x);
    kernel.outputs.extend(std::iter::repeat(nid).take(n_refs));
    *self.rcs.get_mut(&x).unwrap() -= 1;
    debug_assert_eq!(self.graph.shape(nid), kernel.shape());
    self.visited.insert(nid, (kid, op_id));
    Ok(())
}
/// Append a `Permute` movement op producing `nid` from `x`.
fn add_permute_op(&mut self, nid: TensorId, x: TensorId) -> Result<(), ZyxError> {
    let (kid, src) = self.duplicate_or_store(x)?;
    let axes: Vec<_> = self.graph.axes(nid).into();
    let shape = self.graph.shape(nid).into();
    let n_refs = self.rcs[&nid] as usize;
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Move { x: src, mop: Box::new(MoveOp::Permute { axes, shape }) });
    // Consume one output claim of `x` and expose `nid` once per reference.
    kernel.remove_first_output(x);
    kernel.outputs.extend(std::iter::repeat(nid).take(n_refs));
    *self.rcs.get_mut(&x).unwrap() -= 1;
    debug_assert_eq!(self.graph.shape(nid), kernel.shape());
    self.visited.insert(nid, (kid, op_id));
    Ok(())
}
/// Append a `Pad` movement op producing `nid` from `x`.
fn add_pad_op(&mut self, nid: TensorId, x: TensorId) -> Result<(), ZyxError> {
    let (kid, src) = self.duplicate_or_store(x)?;
    let padding = self.graph.padding(nid).into();
    let shape = self.graph.shape(nid).into();
    let n_refs = self.rcs[&nid] as usize;
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Move { x: src, mop: Box::new(MoveOp::Pad { padding, shape }) });
    // Consume one output claim of `x` and expose `nid` once per reference.
    kernel.remove_first_output(x);
    kernel.outputs.extend(std::iter::repeat(nid).take(n_refs));
    *self.rcs.get_mut(&x).unwrap() -= 1;
    debug_assert_eq!(self.graph.shape(nid), kernel.shape());
    self.visited.insert(nid, (kid, op_id));
    Ok(())
}
/// Append a `Reduce` op producing `nid` by reducing `x` over the graph's
/// axes for `nid` with `rop`.
///
/// Reduce axes must be sorted. They are moved to the innermost positions
/// with a permute before the reduce is emitted, and a full reduction is
/// reshaped to `[1]` afterwards so the result keeps a non-empty shape.
fn add_reduce_op(&mut self, nid: TensorId, x: TensorId, rop: BOp) -> Result<(), ZyxError> {
    let axes = self.graph.axes(nid);
    let shape = self.graph.shape(x);
    // Store/reload or duplicate the source kernel exactly like the other
    // ops do; this block was previously an inlined byte-for-byte copy of
    // duplicate_or_store.
    let (kid, mut op_id) = self.duplicate_or_store(x)?;
    #[cfg(debug_assertions)]
    {
        use crate::shape::UAxis;
        let mut sorted_axes: Vec<UAxis> = axes.into();
        sorted_axes.sort_unstable();
        debug_assert_eq!(axes, sorted_axes, "Reduce axes must be sorted.");
    }
    {
        // Move the reduced axes to the innermost positions so the reduce
        // runs over a contiguous tail of the shape.
        let n = shape.len();
        let mut permute_axes = Vec::with_capacity(n);
        let max_axis = *axes.last().unwrap();
        let mut ai = 0;
        for i in 0..=max_axis {
            if axes[ai] == i {
                ai += 1;
            } else {
                permute_axes.push(i);
            }
        }
        permute_axes.extend(max_axis + 1..n);
        permute_axes.extend_from_slice(axes);
        // Skip the permute entirely when it would be the identity.
        if !permute_axes.iter().copied().eq(0..permute_axes.len()) {
            let shape = crate::shape::permute(self.graph.shape(x), &permute_axes);
            op_id = self.kernels[kid]
                .push_back(Op::Move { x: op_id, mop: Box::new(MoveOp::Permute { axes: permute_axes, shape }) });
        }
    }
    let kernel = &mut self.kernels[kid];
    op_id = kernel.push_back(Op::Reduce { x: op_id, rop, n_axes: axes.len() });
    kernel.remove_first_output(x);
    kernel.outputs.extend(vec![nid; self.rcs[&nid] as usize]);
    *self.rcs.get_mut(&x).unwrap() -= 1;
    // A full reduction collapses every axis; keep a rank-1 [1] shape.
    if shape.len() == axes.len() {
        op_id = self.kernels[kid].push_back(Op::Move { x: op_id, mop: Box::new(MoveOp::Reshape { shape: vec![1] }) });
    }
    debug_assert_eq!(self.graph.shape(nid), self.kernels[kid].shape());
    self.visited.insert(nid, (kid, op_id));
    Ok(())
}
/// Append a `Cast` of `x` to `dtype`, producing `nid`.
fn add_cast_op(&mut self, nid: TensorId, x: TensorId, dtype: DType) {
    let (kid, src) = self.visited[&x];
    let n_refs = self.rcs[&nid] as usize;
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Cast { x: src, dtype });
    // Consume one output claim of `x` and expose `nid` once per reference.
    kernel.remove_first_output(x);
    kernel.outputs.extend(std::iter::repeat(nid).take(n_refs));
    *self.rcs.get_mut(&x).unwrap() -= 1;
    self.visited.insert(nid, (kid, op_id));
}
/// Append a unary op `uop` applied to `x`, producing `nid`.
fn add_unary_op(&mut self, nid: TensorId, x: TensorId, uop: UOp) {
    let (kid, src) = self.visited[&x];
    let n_refs = self.rcs[&nid] as usize;
    let kernel = &mut self.kernels[kid];
    let op_id = kernel.push_back(Op::Unary { x: src, uop });
    // Consume one output claim of `x` and expose `nid` once per reference.
    kernel.remove_first_output(x);
    kernel.outputs.extend(std::iter::repeat(nid).take(n_refs));
    *self.rcs.get_mut(&x).unwrap() -= 1;
    self.visited.insert(nid, (kid, op_id));
}
/// Append a `Binary` op combining `x` and `y` into `nid`.
///
/// If both operands already live in the same kernel the op is appended in
/// place. Otherwise any operand kernel that already contains stores is
/// stored and reloaded into a fresh kernel first, then `y`'s kernel is
/// merged into `x`'s: its ops are copied over with remapped op ids, its
/// loads/stores/outputs are transferred, and every `visited` entry that
/// pointed into it is redirected. Operands are swapped beforehand when only
/// `y`'s kernel reduces, so the reduce kernel is always the merge target.
fn add_binary_op(&mut self, nid: TensorId, mut x: TensorId, mut y: TensorId, bop: BOp) -> Result<(), ZyxError> {
let (mut kid, mut op_id) = self.visited[&x];
let (mut kidy, mut op_idy) = self.visited[&y];
let kid_stores = !self.kernels[kid].stores.is_empty();
let kidy_stores = !self.kernels[kidy].stores.is_empty();
let new_op_id = if kid == kidy {
// Both operands in one kernel: consume one output claim of each and
// append the binary op directly.
let kernel = &mut self.kernels[kid];
kernel.remove_first_output(x);
kernel.remove_first_output(y);
kernel.outputs.extend(vec![nid; self.rcs[&nid] as usize]);
kernel.push_back(Op::Binary { x: op_id, y: op_idy, bop })
} else {
// Kernels that already store are not extended further; store the
// operand and restart from a load kernel (duplicated when it still
// has other users).
match (kid_stores, kidy_stores) {
(true, true) => {
self.add_store(x)?;
(kid, op_id) = self.create_load_kernel(x);
if self.kernels[kid].outputs.len() > 1 {
kid = self.duplicate_kernel(x, kid);
self.kernels[kid].outputs.push(x);
}
self.add_store(y)?;
(kidy, op_idy) = self.create_load_kernel(y);
if self.kernels[kidy].outputs.len() > 1 {
kidy = self.duplicate_kernel(y, kidy);
self.kernels[kidy].outputs.push(y);
}
}
(true, false) => {
self.add_store(x)?;
(kid, op_id) = self.create_load_kernel(x);
if self.kernels[kid].outputs.len() > 1 {
kid = self.duplicate_kernel(x, kid);
self.kernels[kid].outputs.push(x);
}
}
(false, true) => {
self.add_store(y)?;
(kidy, op_idy) = self.create_load_kernel(y);
if self.kernels[kidy].outputs.len() > 1 {
kidy = self.duplicate_kernel(y, kidy);
self.kernels[kidy].outputs.push(y);
}
}
(false, false) => {}
}
// If only `y`'s side reduces, swap operands so `kid` (the merge
// target) is the reducing kernel; remember the swap so the operand
// order of the binary op can be restored below.
let swapped_xy = if self.kernels[kidy].is_reduce() && !self.kernels[kid].is_reduce() {
std::mem::swap(&mut kid, &mut kidy);
std::mem::swap(&mut op_id, &mut op_idy);
std::mem::swap(&mut x, &mut y);
true
} else {
false
};
self.kernels[kidy].remove_first_output(y);
// Take `y`'s kernel out of the slab and replay its op list (linked
// through `next`) into `kid`, building an old-id -> new-id map.
let Kernel { outputs, loads, stores, ops, head, tail: _ } = unsafe { self.kernels.remove_and_return(kidy) };
let mut y_ops_map = Map::with_capacity_and_hasher(5, BuildHasherDefault::new());
let mut i = head;
while !i.is_null() {
let mut op = ops[i].op.clone();
// Remap op inputs to the copied ids; assumes the list order places
// an op's inputs before the op itself (TODO confirm invariant).
for param in op.parameters_mut() {
*param = y_ops_map[param];
}
let new_op_id = self.kernels[kid].push_back(op);
y_ops_map.insert(i, new_op_id);
i = ops[i].next;
}
// Redirect visited entries that pointed into the removed kernel.
for (kidm, op_id) in self.visited.values_mut() {
if *kidm == kidy {
*kidm = kid;
if let Some(new_op_id) = y_ops_map.get(op_id) {
*op_id = *new_op_id;
}
}
}
self.kernels[kid].loads.extend(loads);
self.kernels[kid].stores.extend(stores);
self.kernels[kid].remove_first_output(x);
self.kernels[kid].outputs.extend(outputs);
self.kernels[kid].outputs.extend(vec![nid; self.rcs[&nid] as usize]);
// Restore the caller's operand order if we swapped above.
let op = if swapped_xy {
Op::Binary { x: y_ops_map[&op_idy], y: op_id, bop }
} else {
Op::Binary { x: op_id, y: y_ops_map[&op_idy], bop }
};
self.kernels[kid].push_back(op)
};
*self.rcs.get_mut(&x).unwrap() -= 1;
*self.rcs.get_mut(&y).unwrap() -= 1;
self.visited.insert(nid, (kid, new_op_id));
Ok(())
}
/// Emit a store for tensor `x` (unless one was already emitted) and, once
/// the owning kernel has no pending outputs and all of its loads are
/// realized, launch it immediately and free inputs nobody else needs.
fn add_store(&mut self, x: TensorId) -> Result<(), ZyxError> {
let (kid, op_id) = self.visited[&x];
if self.virt_realized_nodes.contains(&x) {
// A store for `x` already exists; only drop the bookkeeping entries.
self.visited.remove(&x).unwrap();
self.kernels[kid].outputs.retain(|&elem| elem != x);
} else {
self.visited.remove(&x).unwrap();
self.virt_realized_nodes.insert(x);
let dtype = self.graph.dtype(x);
self.kernels[kid].push_back(Op::StoreView { src: op_id, dtype });
self.kernels[kid].stores.push(x);
self.kernels[kid].outputs.retain(|&elem| elem != x);
}
// Eager launch: nothing more will be appended (no outputs left) and all
// loaded tensors exist in memory.
if self.kernels[kid].outputs.is_empty() && self.kernels[kid].loads.iter().all(|x| self.realized_nodes.contains(x)) {
let kernel = unsafe { self.kernels.remove_and_return(kid) };
let loads = kernel.loads.clone();
let stores = kernel.stores.clone();
self.launch_kernel(kernel)?;
self.realized_nodes.extend(stores);
// Deallocate loaded tensors that no other pending kernel reads and
// that are not in the must-keep set.
let mut to_remove = Set::with_capacity_and_hasher(1, BuildHasherDefault::new());
for tid in loads {
if !self.kernels.values().any(|kernel| kernel.loads.contains(&tid)) && !self.must_keep_nodes.contains(&tid) {
to_remove.insert(tid);
}
}
deallocate_tensors(&to_remove, self.pools, self.events, self.buffer_map, self.temp_data);
}
Ok(())
}
/// Schedule `kernel` onto a device and launch it, compiling it first unless
/// a cached program (or cached optimization sequence) can be reused.
///
/// Fast paths: a program cached for (kernel, device) launches directly; an
/// optimization sequence cached for (kernel, device-info) is applied and
/// recompiled. Otherwise the kernel is normalized and autotuned, and both
/// the program and the found optimizations are cached.
fn launch_kernel(&mut self, mut kernel: Kernel) -> Result<(), ZyxError> {
if kernel.stores.is_empty() {
// A kernel with no stores indicates a kernelization bug; dump it
// before panicking so it can be inspected.
println!("Empty stores in this kernel:");
kernel.debug();
panic!("Empty stores in this kernel:");
}
debug_assert!(!kernel.stores.is_empty());
debug_assert!(!kernel.ops.is_empty());
self.n_launches += 1;
// Pick device and memory pool and prepare launch arguments/events.
let (dev_id, pool_id, event_wait_list, output_buffers, args) = schedule(
&kernel.loads,
&kernel.stores,
self.graph,
self.devices,
self.pools,
self.events,
self.buffer_map,
)?;
let device = &mut self.devices[dev_id];
let pool = &mut self.pools[pool_id];
let dev_info_id = self.cache.get_or_add_dev_info(device.info());
let kernel_id = if let Some(&kid) = self.cache.kernels.get(&kernel) {
// Fully cached: a compiled program exists for this exact device.
if let Some(&program_id) = self.cache.programs.get(&(kid, dev_id)) {
if self.debug.kmd() {
println!("Kernel launch from memory pool {pool_id:?} with args: {args:?}");
}
let event = device.launch(program_id, pool, &args, event_wait_list)?;
self.events.insert(output_buffers, event);
return Ok(());
}
// Known kernel, new device: reuse the optimization sequence tuned
// for the same device info and only recompile.
if let Some(opt_seq) = self.cache.optimizations.get(&(kid, dev_info_id)) {
opt_seq.apply(&mut kernel, device.info());
let program_id = device.compile(&kernel, self.debug.asm())?;
let event = device.launch(program_id, pool, &args, event_wait_list)?;
self.events.insert(output_buffers, event);
return Ok(());
}
kid
} else {
self.cache.insert_kernel(kernel.clone())
};
if self.debug.sched() {
kernel.debug();
}
// Take perf counters before normalization mutates the op list.
let (flop, read, write) = kernel.flop_mem_rw();
kernel.unfold_movement_ops();
let global_indices = kernel.get_global_indices();
let max_global_dims = device.info().max_global_work_dims.len();
// Merging n indices into one reduces the count by n - 1, so this leaves
// exactly max_global_dims global indices.
if global_indices.len() > max_global_dims {
let n = global_indices.len() + 1 - max_global_dims;
let loops: Vec<OpId> = global_indices.values().copied().take(n).collect();
kernel.merge_indices(&loops);
}
{
// Renumber global and local index axes into dense 0..n ranges,
// keeping their relative order (BTreeMap iterates sorted by axis).
let mut indices = BTreeMap::new();
indices.insert(Scope::Global, BTreeMap::new());
indices.insert(Scope::Local, BTreeMap::new());
for (op_id, op_node) in kernel.ops.iter() {
if let Op::Index { scope, axis, .. } = op_node.op {
indices.get_mut(&scope).unwrap().insert(axis, op_id);
}
}
for (_, scoped_indices) in indices {
let mut ax = 0;
for &idx_id in scoped_indices.values() {
let Op::Index { axis, .. } = &mut kernel.ops[idx_id].op else { unreachable!() };
*axis = ax;
ax += 1;
}
}
kernel.verify();
}
// Autotune, then cache the compiled program and the winning opts.
let (program_id, opts) = kernel.autotune(&args, device, pool, self.autotune_config, flop, read, write, self.debug)?;
self.cache.programs.insert((kernel_id, dev_id), program_id);
self.cache.optimizations.insert((kernel_id, dev_info_id), opts);
let event = device.launch(program_id, pool, &args, event_wait_list)?;
self.events.insert(output_buffers, event);
Ok(())
}
}
impl Runtime {
    /// Walk `order` (a topological order of graph nodes), fuse nodes into
    /// kernels, and launch kernels until every tensor in `to_eval` is
    /// realized.
    ///
    /// `rcs` holds remaining reference counts per node; `realized_nodes`
    /// lists tensors that already have device buffers. Also removes the
    /// previously present empty `#[cfg(debug_assertions)]` block (dead code).
    pub(crate) fn realize_with_order(
        &mut self,
        rcs: Map<TensorId, u32>,
        realized_nodes: Set<TensorId>,
        order: &[TensorId],
        to_eval: &Set<TensorId>,
    ) -> Result<(), ZyxError> {
        let begin = std::time::Instant::now();
        let mut kernelizer = Kernelizer::new(
            realized_nodes,
            to_eval,
            rcs,
            &self.graph,
            &mut self.pools,
            &mut self.events,
            &mut self.buffer_map,
            &mut self.temp_data,
            &mut self.devices,
            &mut self.kernel_cache,
            &self.autotune_config,
            self.debug,
        );
        for &nid in order {
            if kernelizer.is_virt_realized(nid) {
                // Already materialized (or scheduled to be): just load it.
                kernelizer.create_load_kernel(nid);
            } else {
                match self.graph[nid] {
                    // Leaves are realized by definition, handled above.
                    Node::Leaf { .. } => unreachable!(),
                    Node::Const { value } => kernelizer.create_const_kernel(nid, value),
                    Node::Cast { x, dtype } => kernelizer.add_cast_op(nid, x, dtype),
                    Node::Unary { x, uop } => kernelizer.add_unary_op(nid, x, uop),
                    Node::Expand { x } => kernelizer.add_expand_op(nid, x)?,
                    Node::Permute { x } => kernelizer.add_permute_op(nid, x)?,
                    Node::Reshape { x } => kernelizer.add_reshape_op(nid, x)?,
                    Node::Pad { x } => kernelizer.add_pad_op(nid, x)?,
                    Node::Reduce { x, rop } => kernelizer.add_reduce_op(nid, x, rop)?,
                    Node::Binary { x, y, bop } => kernelizer.add_binary_op(nid, x, y, bop)?,
                    Node::Custom(_) => todo!(),
                }
            }
            // Requested tensors are stored right away; if further consumers
            // remain, reload the stored value into a fresh kernel.
            if to_eval.contains(&nid) && !kernelizer.realized_nodes.contains(&nid) {
                kernelizer.add_store(nid)?;
                *kernelizer.rcs.get_mut(&nid).unwrap() -= 1;
                if kernelizer.rcs[&nid] > 0 {
                    kernelizer.create_load_kernel(nid);
                }
            }
        }
        // Launch remaining kernels in dependency order: repeatedly pick any
        // kernel whose loads are all realized.
        if kernelizer.kernels.len() > KMKernelId(0) {
            let mut kids: Vec<KMKernelId> = kernelizer.kernels.ids().collect();
            while let Some(kid) = kids
                .iter()
                .find(|&&kid| {
                    kernelizer.kernels[kid]
                        .loads
                        .iter()
                        .all(|x| kernelizer.realized_nodes.contains(x))
                })
                .copied()
            {
                kids.retain(|x| *x != kid);
                // `kid` was just yielded by the slab, so it is occupied.
                let kernel = unsafe { kernelizer.kernels.remove_and_return(kid) };
                let loads = kernel.loads.clone();
                if !kernel.stores.is_empty() {
                    let stores = kernel.stores.clone();
                    kernelizer.launch_kernel(kernel)?;
                    kernelizer.realized_nodes.extend(stores);
                }
                // Free load tensors no other pending kernel reads, unless
                // they are in the must-keep set.
                let mut to_remove = Set::with_capacity_and_hasher(1, BuildHasherDefault::new());
                for tid in loads {
                    if !kernelizer.kernels.values().any(|kernel| kernel.loads.contains(&tid))
                        && !kernelizer.must_keep_nodes.contains(&tid)
                    {
                        to_remove.insert(tid);
                    }
                }
                deallocate_tensors(
                    &to_remove,
                    kernelizer.pools,
                    kernelizer.events,
                    kernelizer.buffer_map,
                    kernelizer.temp_data,
                );
            }
        }
        #[cfg(debug_assertions)]
        {
            assert!(kernelizer.kernels.len() <= KMKernelId(0));
            debug_assert!(to_eval.is_subset(&kernelizer.realized_nodes));
        }
        let elapsed = begin.elapsed();
        if self.debug.perf() {
            println!(
                "Kernelizer took {} μs for {} kernels",
                elapsed.as_micros(),
                kernelizer.n_launches
            );
        }
        Ok(())
    }
}