use super::{ComputeGraph, TensorNode, TensorOp};
/// Summary of one completed graph execution.
///
/// NOTE(review): not constructed anywhere in this file — presumably filled in
/// by a caller that also times the run; confirm against the call sites.
#[derive(Debug)]
pub struct GraphExecResult {
    /// Number of kernel launches issued for the graph.
    pub n_launches: usize,
    /// Wall-clock execution time in microseconds (per the `_us` suffix),
    /// when the caller measured it; `None` otherwise.
    pub elapsed_us: Option<u64>,
}
/// Backend abstraction over the individual kernel launches needed by
/// `execute_graph`.
///
/// All `*_ptr` arguments are raw `u64` device addresses taken from
/// `TensorNode::data_ptr`; implementors are responsible for interpreting
/// them. Every method returns `Err` to abort graph execution early.
pub trait KernelDispatch {
    /// Launches a matrix-multiply kernel reading activations at `input_ptr`
    /// and writing to `output_ptr`. The executor fills `m`, `n`, `k` from
    /// `node.shape[2]`, `node.shape[0]`, `node.shape[1]`; the weight pointer
    /// travels on `node` itself. NOTE(review): the exact m/n/k convention is
    /// backend-defined — confirm against the kernel implementation.
    fn dispatch_mul_mat(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<(), crate::GpuError>;

    /// Launches an RMS-norm kernel over `m` rows of `hidden_dim` elements,
    /// reading from `input_ptr` and writing to `output_ptr`, with `epsilon`
    /// as the numerical-stability term.
    fn dispatch_rms_norm(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        hidden_dim: u32,
        m: u32,
        epsilon: f32,
    ) -> Result<(), crate::GpuError>;

    /// Elementwise addition of `n_elements` values from `a_ptr` and `b_ptr`
    /// into `output_ptr`.
    fn dispatch_add(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;

    /// Applies rotary position embeddings at `qk_ptr` — presumably in place,
    /// since no output pointer is supplied. NOTE(review): the executor
    /// currently passes an empty `positions` slice; confirm how token
    /// positions actually reach the kernel.
    fn dispatch_rope(
        &mut self,
        node: &TensorNode,
        qk_ptr: u64,
        positions: &[u32],
        head_dim: u32,
        num_heads: u32,
    ) -> Result<(), crate::GpuError>;

    /// Launches the attention kernel for `layer_idx` over the Q/K/V buffers,
    /// writing `m` rows of output to `output_ptr`.
    fn dispatch_attention(
        &mut self,
        node: &TensorNode,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        output_ptr: u64,
        m: u32,
        layer_idx: usize,
    ) -> Result<(), crate::GpuError>;

    /// Copies `size_bytes` bytes from `src_ptr` to `dst_ptr` on-device.
    fn dispatch_copy(
        &mut self,
        src_ptr: u64,
        dst_ptr: u64,
        size_bytes: usize,
    ) -> Result<(), crate::GpuError>;

    /// Elementwise multiplication of `n_elements` values from `a_ptr` and
    /// `b_ptr` into `output_ptr`.
    fn dispatch_mul(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;

    /// Applies the SiLU activation to `n_elements` values from `input_ptr`,
    /// writing to `output_ptr`.
    fn dispatch_silu(
        &mut self,
        input_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;
}
pub fn execute_graph<D: KernelDispatch>(
graph: &ComputeGraph,
dispatcher: &mut D,
) -> Result<usize, crate::GpuError> {
let mut n_launches = 0;
for node in &graph.nodes {
match node.op {
TensorOp::None => {
}
TensorOp::MulMat => {
let input_idx = node.inputs.first().copied().unwrap_or(0);
let input_ptr = graph.nodes[input_idx].data_ptr;
dispatcher.dispatch_mul_mat(
node,
input_ptr,
node.data_ptr,
node.shape[2], node.shape[0], node.shape[1], )?;
n_launches += 1;
}
TensorOp::RmsNorm => {
let input_idx = node.inputs.first().copied().unwrap_or(0);
let input_ptr = graph.nodes[input_idx].data_ptr;
dispatcher.dispatch_rms_norm(
node,
input_ptr,
node.data_ptr,
node.shape[0], node.shape[2], node.params.scalar, )?;
n_launches += 1;
}
TensorOp::Add => {
let a_idx = node.inputs.first().copied().unwrap_or(0);
let b_idx = node.inputs.get(1).copied().unwrap_or(0);
let a_ptr = graph.nodes[a_idx].data_ptr;
let b_ptr = graph.nodes[b_idx].data_ptr;
let n_elements = (node.shape[0] * node.shape[2]) as usize;
dispatcher.dispatch_add(a_ptr, b_ptr, node.data_ptr, n_elements)?;
n_launches += 1;
}
TensorOp::Rope => {
let input_idx = node.inputs.first().copied().unwrap_or(0);
let input_ptr = graph.nodes[input_idx].data_ptr;
dispatcher.dispatch_rope(
node,
input_ptr,
&[], node.shape[0], node.shape[1], )?;
n_launches += 1;
}
TensorOp::SoftMax => {
let q_idx = node.inputs.first().copied().unwrap_or(0);
let k_idx = node.inputs.get(1).copied().unwrap_or(0);
let v_idx = node.inputs.get(2).copied().unwrap_or(0);
dispatcher.dispatch_attention(
node,
graph.nodes[q_idx].data_ptr,
graph.nodes[k_idx].data_ptr,
graph.nodes[v_idx].data_ptr,
node.data_ptr,
node.shape[2], node.params.int_param as usize, )?;
n_launches += 1;
}
TensorOp::Copy => {
let src_idx = node.inputs.first().copied().unwrap_or(0);
let src_ptr = graph.nodes[src_idx].data_ptr;
let size = (node.shape[0] * node.shape[1] * 4) as usize; dispatcher.dispatch_copy(src_ptr, node.data_ptr, size)?;
n_launches += 1;
}
TensorOp::Mul => {
let a_idx = node.inputs.first().copied().unwrap_or(0);
let b_idx = node.inputs.get(1).copied().unwrap_or(0);
let n_elements = (node.shape[0] * node.shape[2]) as usize;
dispatcher.dispatch_mul(
graph.nodes[a_idx].data_ptr,
graph.nodes[b_idx].data_ptr,
node.data_ptr,
n_elements,
)?;
n_launches += 1;
}
TensorOp::Silu => {
let input_idx = node.inputs.first().copied().unwrap_or(0);
let n_elements = (node.shape[0] * node.shape[2]) as usize;
dispatcher.dispatch_silu(
graph.nodes[input_idx].data_ptr,
node.data_ptr,
n_elements,
)?;
n_launches += 1;
}
}
}
Ok(n_launches)
}
#[cfg(test)]
mod tests {
    use super::super::OpParams;
    use super::*;

    /// Stub dispatcher that tallies kernel launches instead of touching a GPU.
    #[derive(Default)]
    struct CountingDispatcher {
        launches: usize,
    }

    impl CountingDispatcher {
        /// Records one launch; every trait method funnels through here.
        fn record(&mut self) -> Result<(), crate::GpuError> {
            self.launches += 1;
            Ok(())
        }
    }

    impl KernelDispatch for CountingDispatcher {
        fn dispatch_mul_mat(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: u32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_rms_norm(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: f32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_add(
            &mut self,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_rope(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: &[u32],
            _: u32,
            _: u32,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_attention(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: u32,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_copy(&mut self, _: u64, _: u64, _: usize) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_mul(
            &mut self,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.record()
        }

        fn dispatch_silu(&mut self, _: u64, _: u64, _: usize) -> Result<(), crate::GpuError> {
            self.record()
        }
    }

    #[test]
    fn test_execute_empty_graph() {
        let graph = ComputeGraph::new();
        let mut dispatcher = CountingDispatcher::default();
        assert_eq!(execute_graph(&graph, &mut dispatcher).unwrap(), 0);
        assert_eq!(dispatcher.launches, 0);
    }

    #[test]
    fn test_execute_single_layer_graph() {
        let mut graph = ComputeGraph::new();

        // Residual-stream input; leaves dispatch nothing.
        let input = graph.add_leaf(0x1000, [1536, 1, 4, 0]);

        // Pre-attention RMSNorm.
        let normed = graph.add_op(
            TensorOp::RmsNorm,
            0x2000,
            [1536, 1, 4, 0],
            vec![input],
            OpParams { gamma_ptr: 0x3000, scalar: 1e-6, ..Default::default() },
        );

        // Q/K/V projections off the normed activations.
        let q = graph.add_op(
            TensorOp::MulMat,
            0x4000,
            [1536, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x5000, ..Default::default() },
        );
        let k = graph.add_op(
            TensorOp::MulMat,
            0x6000,
            [256, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x7000, ..Default::default() },
        );
        let v = graph.add_op(
            TensorOp::MulMat,
            0x8000,
            [256, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x9000, ..Default::default() },
        );

        // Attention over Q/K/V for layer 0.
        let attn = graph.add_op(
            TensorOp::SoftMax,
            0xA000,
            [1536, 1, 4, 0],
            vec![q, k, v],
            OpParams { int_param: 0, ..Default::default() },
        );

        // Residual add back onto the stream.
        let _residual = graph.add_op(
            TensorOp::Add,
            0xB000,
            [1536, 1, 4, 0],
            vec![input, attn],
            OpParams::default(),
        );

        let mut dispatcher = CountingDispatcher::default();
        let launches = execute_graph(&graph, &mut dispatcher).unwrap();

        // One launch per compute op: rmsnorm + 3 matmuls + attention + add.
        assert_eq!(launches, 6);
        assert_eq!(dispatcher.launches, 6);
        assert_eq!(graph.n_ops(), 6);
    }
}
}