pub mod compiled;
mod search;
use crate::kernel::{BOp, UOp};
use crate::slab::SlabId;
use crate::tensor::TensorId;
use crate::{
DType,
shape::{Dim, UAxis},
slab::Slab,
};
use crate::{Map, Set};
use std::hash::BuildHasherDefault;
/// A single operation in the lazy computation [`Graph`].
///
/// Movement ops (`Expand`, `Permute`, `Reshape`, `Pad`) and `Reduce` store
/// only their input id here; the associated shape/axes/padding details live
/// in the side maps on [`Graph`] (`shapes`, `axes`, `paddings`). Which map
/// each op consults is determined by the callers and is not visible in this
/// file — confirm at the call sites of `push_axes`/`push_padding`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
pub enum Node {
    /// Scalar constant; its dtype is derived from `value` (see `Graph::dtype`).
    Const {
        value: Constant,
    },
    /// Source tensor with no inputs; only its dtype is stored here.
    Leaf {
        dtype: DType,
    },
    /// Broadcast `x` to a larger shape.
    Expand {
        x: TensorId,
    },
    /// Reorder the axes of `x`.
    Permute {
        x: TensorId,
    },
    /// Reinterpret `x` with a new shape.
    Reshape {
        x: TensorId,
    },
    /// Pad `x`; per-dimension padding is recorded via `Graph::push_padding`.
    Pad {
        x: TensorId,
    },
    /// Reduce `x` using binary op `rop` (e.g. sum/max-style reductions).
    Reduce {
        x: TensorId,
        rop: BOp,
    },
    /// Convert `x` to `dtype`.
    Cast {
        x: TensorId,
        dtype: DType,
    },
    /// Elementwise unary op applied to `x`.
    Unary {
        x: TensorId,
        uop: UOp,
    },
    /// Elementwise binary op applied to `x` and `y`; both inputs must share
    /// one shape (enforced by the debug assertion in `Graph::push`).
    Binary {
        x: TensorId,
        y: TensorId,
        bop: BOp,
    },
    /// User-provided custom kernel; boxed, likely to keep the enum small.
    #[allow(unused)]
    Custom(Box<crate::kernel::custom::CustomKernel>),
}
/// Lazy tensor computation graph: reference-counted [`Node`]s plus side
/// tables for shape/padding/axes metadata and the autograd gradient tape.
#[derive(Debug)]
pub struct Graph {
    /// Live nodes keyed by [`TensorId`]; each entry is `(refcount, node)`.
    pub nodes: Slab<TensorId, (u32, Node)>,
    // Presumably counts outstanding requests to keep gradient recording
    // active — TODO confirm at the call sites that mutate it.
    pub gradient_tape_ref_count: u32,
    /// When `Some`, gradient recording is on and this holds the ids of
    /// nodes recorded for backpropagation (see `build_topo`).
    pub gradient_tape: Option<Set<TensorId>>,
    /// Explicit output shapes; nodes without an entry inherit the shape of
    /// their first input (see `shape`).
    pub shapes: Map<TensorId, Box<[Dim]>>,
    /// Per-tensor `(before, after)` padding, recorded via `push_padding`.
    paddings: Map<TensorId, Box<[(i64, i64)]>>,
    /// Per-tensor axis lists, recorded via `push_axes`.
    axes: Map<TensorId, Box<[UAxis]>>,
}
impl Graph {
/// Creates an empty graph with gradient tracking disabled.
pub(super) const fn new() -> Self {
    Self {
        nodes: Slab::new(),
        gradient_tape_ref_count: 0,
        gradient_tape: None,
        // BuildHasherDefault::new() is const-callable, which keeps this
        // constructor usable in const contexts.
        shapes: Map::with_hasher(BuildHasherDefault::new()),
        paddings: Map::with_hasher(BuildHasherDefault::new()),
        axes: Map::with_hasher(BuildHasherDefault::new()),
    }
}
/// Returns `true` when the graph contains no live nodes.
pub(super) fn is_empty(&self) -> bool {
    // Slab::len is expressed as a TensorId here; ZERO means no entries.
    self.nodes.len() == TensorId::ZERO
}
/// Increments the reference count of tensor `x`.
/// NOTE(review): indexes the slab directly, so behavior for a dead/unknown
/// id depends on `Slab`'s `IndexMut` impl — presumably a panic; confirm.
pub(super) fn retain(&mut self, x: TensorId) {
    self.nodes[x].0 += 1;
}
/// Decrements the reference counts of the given tensors and, transitively,
/// of the inputs of any node whose count drops to zero. Dead nodes are
/// removed together with their shape/axes/padding entries and their
/// gradient-tape membership. Returns the set of removed ids.
pub(super) fn release(&mut self, x: &[TensorId]) -> Set<TensorId> {
    // Worklist of ids still to be released; dropping a node releases its
    // inputs, which may cascade further.
    let mut stack: Vec<TensorId> = x.to_vec();
    let mut removed = Set::with_capacity_and_hasher(10, BuildHasherDefault::default());
    while let Some(id) = stack.pop() {
        // Ids already removed (or never present) are silently skipped.
        let Some((rc, node)) = self.nodes.get_mut(id) else {
            continue;
        };
        *rc = rc.saturating_sub(1);
        if *rc != 0 {
            continue;
        }
        // Refcount hit zero: queue its inputs and drop all metadata.
        stack.extend(node.parameters());
        removed.insert(id);
        self.nodes.remove(id);
        _ = self.shapes.remove(&id);
        _ = self.axes.remove(&id);
        _ = self.paddings.remove(&id);
        if let Some(tape) = self.gradient_tape.as_mut() {
            _ = tape.remove(&id);
        }
    }
    removed
}
pub(super) fn push(&mut self, node: Node) -> TensorId {
#[cfg(debug_assertions)]
{
let mut shape = None;
for nid in node.parameters() {
if let Some(sh) = shape {
let shape = self.shape(nid);
if sh != shape {
println!("{:?}", self.shapes);
panic!("{sh:?} != {shape:?} Pushing new node {node:?}");
}
} else {
shape = Some(self.shape(nid));
}
}
}
for nid in node.parameters() {
self.nodes[nid].0 += 1;
}
let nid = self.nodes.push((1, node));
if let Some(tape) = self.gradient_tape.as_mut() {
tape.insert(nid);
}
nid
}
/// Pushes `node` and records `shape` as its explicit output shape.
pub(super) fn push_wshape(&mut self, node: Node, shape: Vec<Dim>) -> TensorId {
    let nid = self.push(node);
    // Vec<Dim> -> Box<[Dim]> via the std From impl.
    self.shapes.insert(nid, shape.into());
    nid
}
/// Records `(before, after)` padding metadata for tensor `id`;
/// retrieved later via `padding`.
pub(super) fn push_padding(&mut self, id: TensorId, padding: Vec<(i64, i64)>) {
    self.paddings.insert(id, padding.into_boxed_slice());
}
/// Records the axis list for tensor `id`; retrieved later via `axes`.
pub(super) fn push_axes(&mut self, id: TensorId, axes: Vec<UAxis>) {
    self.axes.insert(id, axes.into_boxed_slice());
}
/// Gives `id` its own explicit entry in the shapes map by copying the
/// shape it currently resolves to (possibly inherited from an ancestor —
/// see `shape`).
pub(super) fn add_shape(&mut self, id: TensorId) {
    let shape = self.shape(id).into();
    self.shapes.insert(id, shape);
}
/// Resolves the data type of `tensor_id` by walking up through its
/// ancestors until a node that fixes the dtype is found: a constant
/// (dtype of its value), a leaf or cast (stored dtype), or a
/// comparison/logical binary (always `Bool`). All other ops inherit the
/// dtype of their first input. The iteration cap only guards against a
/// corrupted (cyclic) graph.
pub(super) fn dtype(&self, tensor_id: TensorId) -> DType {
    let mut tensor_id = tensor_id;
    for _ in 0..100_000 {
        match self.nodes[tensor_id].1 {
            Node::Const { value } => return value.dtype(),
            Node::Leaf { dtype } | Node::Cast { dtype, .. } => return dtype,
            Node::Binary { bop, .. } if bop.returns_bool() => return DType::Bool,
            // Dtype-preserving op: continue from its first input.
            _ => tensor_id = self.nodes[tensor_id].1.param1(),
        }
    }
    panic!("DType of {tensor_id:?} could not be found. This is internal bug.")
}
/// Returns the padding recorded for `tensor_id` via `push_padding`.
/// Panics if no padding was recorded for this id.
pub(super) fn padding(&self, tensor_id: TensorId) -> &[(i64, i64)] {
    &self.paddings[&tensor_id]
}
/// Returns the axis list recorded for `tensor_id` via `push_axes`.
/// Panics if no axes were recorded for this id.
pub(super) fn axes(&self, tensor_id: TensorId) -> &[UAxis] {
    &self.axes[&tensor_id]
}
/// Returns the shape of `tensor_id`.
///
/// Nodes without an explicit entry in `shapes` inherit the shape of their
/// first input, so this walks up the chain until an entry is found.
/// Constants are implicitly scalar with shape `[1]`. The iteration cap
/// only guards against a corrupted (cyclic) graph.
pub(super) fn shape(&self, tensor_id: TensorId) -> &[Dim] {
    let mut tensor_id = tensor_id;
    for _ in 0..1_000_000 {
        if let Some(shape) = self.shapes.get(&tensor_id) {
            return shape;
        } else if let Node::Const { .. } = self.nodes[tensor_id].1 {
            return &[1];
        }
        tensor_id = self.nodes[tensor_id].1.param1();
    }
    panic!("Shape of {tensor_id:?} could not be found. This is internal bug.")
}
/// Builds, for autograd, the list of tensors between `x` and `sources`
/// through which gradient flows, ordered users-before-inputs.
///
/// Returns an empty vec when gradient tracking is off (no tape).
pub(super) fn build_topo(&self, x: TensorId, sources: &Set<TensorId>) -> Vec<TensorId> {
    let Some(tape) = self.gradient_tape.as_ref() else {
        return Vec::new();
    };
    // Pass 1: count how many times each node is referenced within the
    // subgraph reachable from `x`.
    let mut params: Vec<TensorId> = vec![x];
    let mut rcs: Map<TensorId, u32> = Map::with_capacity_and_hasher(100, BuildHasherDefault::new());
    while let Some(nid) = params.pop() {
        rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
            // Comparison/logic/bitwise binaries stop the traversal here —
            // presumably because gradient does not flow through them.
            if matches!(
                self.nodes[nid].1,
                Node::Binary {
                    bop: BOp::Cmpgt
                    | BOp::Cmplt
                    | BOp::Eq
                    | BOp::NotEq
                    | BOp::Or
                    | BOp::And
                    | BOp::BitAnd
                    | BOp::BitOr
                    | BOp::BitXor
                    | BOp::BitShiftLeft
                    | BOp::BitShiftRight,
                    ..
                }
            ) {
                return 1;
            }
            // Only descend through nodes recorded on the gradient tape.
            if tape.contains(&nid) {
                params.extend(self.nodes[nid].1.parameters());
            }
            1
        });
    }
    // Pass 2: emit each node once its last counted reference is popped,
    // producing `order` with users before their inputs.
    let mut order = Vec::new();
    let mut internal_rcs: Map<TensorId, u32> = Map::with_capacity_and_hasher(100, BuildHasherDefault::new());
    let mut params: Vec<TensorId> = vec![x];
    while let Some(nid) = params.pop() {
        if let Some(&rc) = rcs.get(&nid) {
            if rc == *internal_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1) {
                order.push(nid);
                params.extend(self.nodes[nid].1.parameters());
            }
        }
    }
    // Pass 3: walking inputs-first (hence .rev()), keep only nodes that
    // transitively depend on one of `sources`; `visited` prevents a node
    // with two qualifying inputs from being pushed twice.
    let mut topo = Vec::new();
    let mut req_grad = sources.clone();
    let mut visited = Set::with_capacity_and_hasher(100, BuildHasherDefault::new());
    for nid in order.into_iter().rev() {
        for p in self.nodes[nid].1.parameters() {
            if req_grad.contains(&p) && visited.insert(nid) {
                req_grad.insert(nid);
                topo.push(nid);
            }
        }
    }
    // `topo` was filled inputs-first; reverse to users-before-inputs.
    topo.reverse();
    topo
}
/// Renders this graph as a Graphviz dot string.
///
/// When `ids` is empty, every node in the graph is plotted; otherwise the
/// subgraph reachable from `ids` plus the nodes depending on it. Tensors
/// present in `buffer_map` (realized to a backend buffer) get a thick
/// dark-red border; tensors still referenced from outside the graph are
/// filled coral, purely internal ones aqua.
#[must_use]
pub fn plot_dot_graph(
    &self,
    ids: &Set<TensorId>,
    buffer_map: &crate::Map<crate::tensor::TensorId, crate::backend::BufferId>,
) -> String {
    use core::fmt::Write;
    use std::format as f;
    // An empty selection means "plot the whole graph".
    let ids: Set<TensorId> = if ids.is_empty() {
        self.nodes.ids().collect()
    } else {
        ids.clone()
    };
    // Pass 1: count references to each node reachable from `ids`.
    let mut params: Vec<TensorId> = ids.iter().copied().collect();
    let mut rcs: Map<TensorId, u8> = Map::with_capacity_and_hasher(100, BuildHasherDefault::new());
    while let Some(nid) = params.pop() {
        rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
            params.extend(self.nodes[nid].1.parameters());
            1
        });
    }
    // Pass 2: emit each node once its last counted reference is popped,
    // yielding `order` with users before their inputs. The `rcs.get` guard
    // mirrors `build_topo`; the old code indexed `rcs[&nid]` (which panics
    // on a miss) and then redundantly re-checked `rcs.contains_key(&nid)`,
    // a condition that was always true at that point.
    let mut order = Vec::new();
    let mut internal_rcs: Map<TensorId, u8> = Map::with_capacity_and_hasher(100, BuildHasherDefault::new());
    let mut params: Vec<TensorId> = ids.iter().copied().collect();
    while let Some(nid) = params.pop() {
        if let Some(&rc) = rcs.get(&nid) {
            if rc == *internal_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1) {
                order.push(nid);
                params.extend(self.nodes[nid].1.parameters());
            }
        }
    }
    // Pass 3: walking inputs-first, collect `ids` plus every node whose
    // input is already in the set (i.e. everything depending on `ids`).
    let mut topo: Set<TensorId> = ids.iter().copied().collect();
    for nid in order.into_iter().rev() {
        for p in self.nodes[nid].1.parameters() {
            if topo.contains(&p) {
                topo.insert(nid);
            }
        }
    }
    // External reference count: total refcount minus references from other
    // nodes. A value > 0 means something outside the graph still holds it.
    let mut user_rc: Map<TensorId, u32> = self.nodes.iter().map(|(k, (rc, _))| (k, *rc)).collect();
    for (_, node) in self.nodes.values() {
        for param in node.parameters() {
            *user_rc.get_mut(&param).unwrap() -= 1;
        }
    }
    let realized_nodes: Set<TensorId> = buffer_map.keys().copied().collect();
    let mut res_dot_graph = String::from("strict digraph {\n ordering=in\n rank=source\n rankdir=LR\n");
    // "NL" is a placeholder for newlines inside labels, replaced at the end
    // so that the write! format string stays on one line.
    let mut add_node = |i: TensorId, text: &str, shape: &str| {
        let fillcolor = if user_rc[&i] > 0 { "coral" } else { "aqua" };
        let (border_color, border_width) = if realized_nodes.contains(&i) {
            ("darkred", 5)
        } else {
            ("black", 1)
        };
        write!(res_dot_graph, " {i}[label=\"{} x {}NL{}NL{:?}\", shape={}, fillcolor=\"{}\", style=filled, color=\"{border_color}\", penwidth={border_width}]", i, self.nodes[i].0, text, self.shape(i), shape, fillcolor).unwrap();
        writeln!(res_dot_graph).unwrap();
    };
    let mut edges = String::new();
    for &id in &topo {
        let node = &self.nodes[id].1;
        match node {
            Node::Const { value } => add_node(id, &f!("Const({value:?})"), "box"),
            Node::Leaf { dtype } => {
                add_node(id, &f!("Leaf({:?}, {})", self.shape(id), dtype), "box");
            }
            Node::Cast { x, dtype } => add_node(id, &f!("C-{dtype}({x})"), "oval"),
            Node::Unary { x, uop } => add_node(id, &f!("{uop:?}({x})"), "oval"),
            Node::Binary { x, y, bop } => add_node(id, &f!("{bop:?}({x}, {y})"), "oval"),
            Node::Reshape { x } => add_node(id, &f!("Reshape({x})"), "oval"),
            Node::Permute { x } => add_node(id, &f!("Permute({x})"), "oval"),
            Node::Expand { x } => add_node(id, &f!("Expand({x})"), "oval"),
            Node::Pad { x } => add_node(id, &f!("Pad({x})"), "oval"),
            Node::Reduce { x, rop } => add_node(id, &f!("{rop:?}({x})"), "oval"),
            Node::Custom(_) => todo!(),
        }
        for param in node.parameters() {
            writeln!(edges, " {param} -> {id}").unwrap();
        }
    }
    res_dot_graph = res_dot_graph.replace("NL", "\n");
    write!(res_dot_graph, "{edges}}}").unwrap();
    res_dot_graph
}
}
impl std::ops::Index<TensorId> for Graph {
    type Output = Node;

    /// Returns the node for `index`, skipping the refcount stored with it.
    fn index(&self, index: TensorId) -> &Self::Output {
        &self.nodes[index].1
    }
}
impl std::ops::IndexMut<TensorId> for Graph {
    /// Mutable access to the node for `index` (refcount is not exposed).
    fn index_mut(&mut self, index: TensorId) -> &mut Self::Output {
        &mut self.nodes[index].1
    }
}
use crate::dtype::Constant;
impl BOp {
pub const fn is_associative(self) -> bool {
use BOp::{Add, And, BitAnd, BitOr, BitShiftLeft, BitShiftRight, BitXor, Max, Mul, Or};
matches!(
self,
Add | Mul | And | Or | BitXor | BitAnd | BitOr | BitShiftLeft | BitShiftRight | Max
)
}
pub const fn is_commutative(self) -> bool {
use BOp::{Add, And, BitAnd, BitOr, BitXor, Max, Mul, Or};
matches!(self, Add | Mul | And | Or | BitXor | BitAnd | BitOr | Max)
}
pub const fn returns_bool(self) -> bool {
use BOp::{And, Cmpgt, Cmplt, Eq, NotEq, Or};
matches!(self, Cmpgt | Cmplt | NotEq | Eq | And | Or)
}
}
/// Iterator over a node's input tensor ids (at most two).
pub struct NodeParametersIterator {
    // Input ids; slots at index >= len are unused padding (TensorId::ZERO).
    parameters: [TensorId; 2],
    // Number of valid entries in `parameters`.
    len: u8,
    // Next index to yield.
    idx: u8,
}
impl Iterator for NodeParametersIterator {
    type Item = TensorId;

    /// Yields the stored parameter ids until `len` of them were produced.
    fn next(&mut self) -> Option<Self::Item> {
        (self.idx != self.len).then(|| {
            let id = self.parameters[self.idx as usize];
            self.idx += 1;
            id
        })
    }
}
impl Node {
    /// Returns an iterator over the tensor ids this node reads
    /// (zero, one, or two of them).
    pub const fn parameters(&self) -> NodeParametersIterator {
        match self {
            // Source nodes have no inputs; the slots are unused padding.
            Node::Const { .. } | Node::Leaf { .. } => {
                NodeParametersIterator { parameters: [TensorId::ZERO, TensorId::ZERO], idx: 0, len: 0 }
            }
            Node::Unary { x, .. }
            | Node::Cast { x, .. }
            | Node::Reshape { x, .. }
            | Node::Expand { x, .. }
            | Node::Permute { x, .. }
            | Node::Pad { x, .. }
            | Node::Reduce { x, .. } => NodeParametersIterator { parameters: [*x, TensorId::ZERO], idx: 0, len: 1 },
            Node::Binary { x, y, .. } => NodeParametersIterator { parameters: [*x, *y], idx: 0, len: 2 },
            Node::Custom(_) => todo!(),
        }
    }
    /// Returns this node's first input id.
    ///
    /// # Panics
    /// `unreachable!` for `Const`/`Leaf` (they have no inputs — callers
    /// such as `Graph::shape`/`Graph::dtype` handle those variants first);
    /// `todo!` for `Custom`.
    pub const fn param1(&self) -> TensorId {
        match *self {
            Node::Const { .. } | Node::Leaf { .. } => unreachable!(),
            Node::Expand { x }
            | Node::Permute { x }
            | Node::Reshape { x }
            | Node::Pad { x }
            | Node::Reduce { x, .. }
            | Node::Cast { x, .. }
            | Node::Unary { x, .. }
            | Node::Binary { x, .. } => x,
            Node::Custom(_) => todo!(),
        }
    }
}