pub mod executor;
pub use executor::{execute_graph, KernelDispatch};
/// The operation a [`TensorNode`] performs.
///
/// Op names mirror common transformer kernels (matrix multiply, RMS norm,
/// rotary embedding, softmax, SiLU, …); the actual kernel mapping lives in
/// the `executor` module — confirm semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorOp {
// Matrix multiplication (weights supplied via `OpParams::weight_ptr`).
MulMat,
// Element-wise addition of two inputs.
Add,
// RMS normalization (`OpParams::gamma_ptr` + `OpParams::scalar` epsilon).
RmsNorm,
// Rotary positional embedding.
Rope,
// Softmax.
SoftMax,
// Element-wise multiplication.
Mul,
// SiLU activation.
Silu,
// Copy / data movement.
Copy,
// No operation: marks a leaf (input) node, see `ComputeGraph::add_leaf`.
None,
}
/// One node of a [`ComputeGraph`]: either a leaf tensor (op == `TensorOp::None`,
/// no inputs) or an operation applied to previously added nodes.
#[derive(Debug, Clone)]
pub struct TensorNode {
// Operation this node performs; `TensorOp::None` marks a leaf.
pub op: TensorOp,
// Raw address of this node's data buffer. Opaque u64 — presumably a
// device/host pointer consumed by the executor; confirm against `executor`.
pub data_ptr: u64,
// Up to four dimension extents. NOTE(review): the tests pass [1536, 1, 1, 0],
// so trailing entries may be padding/unused — confirm the layout convention.
pub shape: [u32; 4],
// Indices into `ComputeGraph::nodes` of this node's operands (empty for leaves).
pub inputs: Vec<usize>,
// Per-op auxiliary parameters (weight/gamma pointers, scalar, etc.).
pub params: OpParams,
}
/// Auxiliary parameters for a [`TensorNode`]'s operation. Fields are
/// interpreted per-op; unused fields stay at their `Default` (zero) values.
#[derive(Debug, Clone, Default)]
pub struct OpParams {
// Weight buffer address, e.g. for `TensorOp::MulMat`. Opaque u64 pointer.
pub weight_ptr: u64,
// Scale (gamma) buffer address, e.g. for `TensorOp::RmsNorm`.
pub gamma_ptr: u64,
// Scalar parameter; used as the epsilon (1e-6) for RmsNorm in the tests —
// other ops' use of it is not visible here.
pub scalar: f32,
// Integer parameter; meaning is op-specific — TODO confirm with the executor.
pub int_param: u32,
}
/// A flat computation graph: nodes are appended in construction order, so
/// (since ops reference earlier indices) the list is in topological order.
#[derive(Debug, Clone)]
pub struct ComputeGraph {
// All nodes, leaves and ops interleaved in insertion order.
pub nodes: Vec<TensorNode>,
// Count of leaf nodes added via `add_leaf` (used by `n_ops`).
pub n_leafs: usize,
}
impl ComputeGraph {
    /// Creates an empty graph with no leaves and no ops.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            n_leafs: 0,
        }
    }

    /// Appends a leaf (input) node and returns its index into `nodes`.
    ///
    /// Leaves carry `TensorOp::None`, have no inputs, and only supply data
    /// (`data_ptr`, `shape`) to downstream op nodes.
    pub fn add_leaf(&mut self, data_ptr: u64, shape: [u32; 4]) -> usize {
        let idx = self.nodes.len();
        self.nodes.push(TensorNode {
            op: TensorOp::None,
            data_ptr,
            shape,
            inputs: Vec::new(),
            params: OpParams::default(),
        });
        self.n_leafs += 1;
        idx
    }

    /// Appends an op node and returns its index into `nodes`.
    ///
    /// `inputs` must be indices of previously added nodes; in debug builds
    /// this is checked so the node list stays in topological order and an
    /// executor can evaluate nodes front-to-back.
    pub fn add_op(
        &mut self,
        op: TensorOp,
        data_ptr: u64,
        shape: [u32; 4],
        inputs: Vec<usize>,
        params: OpParams,
    ) -> usize {
        let idx = self.nodes.len();
        // A forward or out-of-range reference would silently corrupt the
        // graph (and panic later on indexing); catch it at build time.
        debug_assert!(
            inputs.iter().all(|&i| i < idx),
            "add_op: input index out of range; inputs must reference earlier nodes"
        );
        self.nodes.push(TensorNode {
            op,
            data_ptr,
            shape,
            inputs,
            params,
        });
        idx
    }

    /// Number of op (non-leaf) nodes. Cannot underflow: `n_leafs` is only
    /// ever incremented together with a push into `nodes`.
    pub fn n_ops(&self) -> usize {
        self.nodes.len() - self.n_leafs
    }
}
impl Default for ComputeGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed graph holds nothing.
    #[test]
    fn test_empty_graph() {
        let graph = ComputeGraph::new();
        assert_eq!(graph.n_leafs, 0);
        assert_eq!(graph.nodes.len(), 0);
        assert_eq!(graph.n_ops(), 0);
    }

    /// leaf -> RmsNorm -> MulMat: counts and wiring come out right.
    #[test]
    fn test_simple_graph() {
        let mut graph = ComputeGraph::new();
        let shape = [1536, 1, 1, 0];

        let input = graph.add_leaf(0x1000, shape);

        let norm_params = OpParams {
            gamma_ptr: 0x3000,
            scalar: 1e-6,
            ..Default::default()
        };
        let normed = graph.add_op(TensorOp::RmsNorm, 0x2000, shape, vec![input], norm_params);

        let proj_params = OpParams {
            weight_ptr: 0x5000,
            ..Default::default()
        };
        let q = graph.add_op(TensorOp::MulMat, 0x4000, shape, vec![normed], proj_params);

        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.n_leafs, 1);
        assert_eq!(graph.n_ops(), 2);
        assert_eq!(graph.nodes[q].op, TensorOp::MulMat);
        assert_eq!(graph.nodes[q].inputs, vec![normed]);
    }
}