use crate::{Op, Shape};
use crate::provenance::NodeOrigin;
/// Identifier of a node inside a graph.
///
/// The wrapped `u32` is the node's position in the owning graph's node list
/// (`Graph::push_ext` hands out `NodeId(nodes.len())` for each new node).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u32);

impl std::fmt::Display for NodeId {
    /// Renders the id in IR style as `%<index>`, e.g. `%3`.
    ///
    /// The number is written through a nested `write!`, so width/fill flags
    /// given to the outer formatter are not applied to it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let index = self.0;
        f.write_str("%")?;
        write!(f, "{index}")
    }
}
/// A single operation in a graph.
///
/// Nodes are stored in the owning graph's node list; `Graph::push_ext` sets
/// `id` to the node's index in that list at creation time.
// NOTE: field order is significant for the derived `Debug` output and, with
// the `serialize` feature, for the serde representation.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct Node {
    /// This node's own id (its index in the graph when added via the graph's push methods).
    pub id: NodeId,
    /// The operation this node performs.
    pub op: Op,
    /// Ids of the nodes this node consumes, in stored order. Not validated
    /// against the graph when the node is inserted.
    pub inputs: Vec<NodeId>,
    /// Shape attached to this node (printed after `:` by the graph's
    /// `Display`); presumably the shape of the node's result — confirm.
    pub shape: Shape,
    /// Optional human-readable name; not included in the graph's `Display` output.
    pub name: Option<String>,
    /// Optional provenance record; `Graph::push` / `append_node` leave it `None`.
    pub origin: Option<NodeOrigin>,
}
/// A named graph of [`Node`]s together with the ids of its output nodes.
///
/// Nodes are kept in insertion order and `NodeId(i)` indexes the `i`-th
/// stored node. Iteration helpers such as `topo_order` walk that order
/// directly, so they assume every node was added after its inputs.
// NOTE: field order is significant for the derived `Debug` output and, with
// the `serialize` feature, for the serde representation.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug)]
pub struct Graph {
    /// Graph name, printed as `graph @<name>` by `Display`.
    pub name: String,
    // Private so ids are only handed out by the graph's push methods,
    // keeping `NodeId(i)` == index `i`.
    nodes: Vec<Node>,
    /// Ids of the nodes whose values the graph returns (the `return` line in `Display`).
    pub outputs: Vec<NodeId>,
}
impl PartialEq for Graph {
    /// Structural equality on name, outputs and node wiring.
    ///
    /// Two graphs are equal when they have the same name, the same output
    /// list, the same number of nodes, and — node by node, in storage
    /// order — identical `inputs` lists. Previously only the node *count*
    /// was compared, so graphs with entirely different wiring were equal.
    ///
    /// Node ops, shapes, names and origins are not compared.
    // NOTE(review): extend the per-node comparison to `op`/`shape` if those
    // types implement `PartialEq`; that is not visible from this file.
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            && self.outputs == other.outputs
            && self.nodes.len() == other.nodes.len()
            && self
                .nodes
                .iter()
                .zip(&other.nodes)
                .all(|(lhs, rhs)| lhs.inputs == rhs.inputs)
    }
}
impl Graph {
    /// Creates an empty graph with the given name and no outputs.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            nodes: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Number of nodes in the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the graph contains no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the node with the given id.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a node of this graph.
    pub fn node(&self, id: NodeId) -> &Node {
        &self.nodes[id.0 as usize]
    }

    /// All nodes, in insertion order.
    pub fn nodes(&self) -> &[Node] {
        &self.nodes
    }

    /// Shape stored on node `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a node of this graph.
    pub fn shape(&self, id: NodeId) -> &Shape {
        &self.node(id).shape
    }

    /// Replaces the graph's output list.
    pub fn set_outputs(&mut self, outputs: Vec<NodeId>) {
        self.outputs = outputs;
    }

    /// Replaces the input list of node `id`.
    ///
    /// The new inputs are not validated; pointing a node at a later node
    /// breaks the ordering assumption of [`Graph::topo_order`].
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a node of this graph.
    pub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>) {
        self.nodes[id.0 as usize].inputs = inputs;
    }

    /// Mutable access to node `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a node of this graph.
    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
        &mut self.nodes[id.0 as usize]
    }

    /// Mutable access to all nodes. Nodes cannot be added or removed through
    /// this slice, only modified in place.
    pub fn nodes_mut(&mut self) -> &mut [Node] {
        &mut self.nodes
    }

    /// Public entry point for adding a node without provenance; returns its id.
    pub fn append_node(
        &mut self,
        op: Op,
        inputs: Vec<NodeId>,
        shape: Shape,
        name: Option<String>,
    ) -> NodeId {
        self.push(op, inputs, shape, name)
    }

    /// Appends a node with no origin information and returns its id.
    pub(crate) fn push(
        &mut self,
        op: Op,
        inputs: Vec<NodeId>,
        shape: Shape,
        name: Option<String>,
    ) -> NodeId {
        self.push_ext(op, inputs, shape, name, None)
    }

    /// Appends a node and returns its id, which equals its index in the node list.
    ///
    /// `inputs` are stored as given and are not checked to exist.
    ///
    /// # Panics
    ///
    /// Panics if the graph already holds more than `u32::MAX` nodes, since
    /// the new id would not fit in a [`NodeId`].
    pub(crate) fn push_ext(
        &mut self,
        op: Op,
        inputs: Vec<NodeId>,
        shape: Shape,
        name: Option<String>,
        origin: Option<NodeOrigin>,
    ) -> NodeId {
        let index = self.nodes.len();
        // A plain `as u32` would silently wrap and hand out an id aliasing an
        // existing node; fail loudly instead.
        assert!(
            index <= u32::MAX as usize,
            "graph `{}` already holds {} nodes; NodeId cannot represent more",
            self.name,
            index
        );
        let id = NodeId(index as u32);
        self.nodes.push(Node {
            id,
            op,
            inputs,
            shape,
            name,
            origin,
        });
        id
    }

    /// Ids of the nodes that list `id` among their inputs, in node order.
    ///
    /// A node that consumes `id` several times appears only once. Linear in
    /// the total number of node inputs.
    pub fn users(&self, id: NodeId) -> Vec<NodeId> {
        self.nodes
            .iter()
            .filter(|n| n.inputs.contains(&id))
            .map(|n| n.id)
            .collect()
    }

    /// Number of nodes that list `id` among their inputs.
    ///
    /// Same as `self.users(id).len()` but without allocating; a node that
    /// consumes `id` several times counts once. Graph outputs are not counted.
    pub fn use_count(&self, id: NodeId) -> usize {
        self.nodes.iter().filter(|n| n.inputs.contains(&id)).count()
    }

    /// Node ids in insertion order.
    ///
    /// This is a topological order only if every node was added after all of
    /// its inputs. Neither `push_ext` nor `set_inputs` enforces that.
    pub fn topo_order(&self) -> impl Iterator<Item = NodeId> + '_ {
        // Cannot truncate: `push_ext` refuses to create ids beyond `u32::MAX`.
        (0..self.nodes.len()).map(|i| NodeId(i as u32))
    }

    /// Node ids in reverse insertion order (see [`Graph::topo_order`] for
    /// the ordering assumption).
    pub fn reverse_topo(&self) -> impl Iterator<Item = NodeId> + '_ {
        (0..self.nodes.len()).rev().map(|i| NodeId(i as u32))
    }

    /// Delegates to `crate::GraphModule::define` with the given name and HIR
    /// builder closure.
    pub fn define(
        name: impl Into<String>,
        build: impl FnOnce(&mut crate::hir::HirModule) -> crate::hir::HirNodeId,
    ) -> crate::GraphModule {
        crate::GraphModule::define(name, build)
    }

    /// Delegates to `crate::GraphModule::hir`.
    pub fn hir(name: impl Into<String>) -> crate::GraphModule {
        crate::GraphModule::hir(name)
    }

    /// Wraps this graph in a `crate::GraphModule` (via `GraphModule::from_graph`).
    pub fn module(self) -> crate::GraphModule {
        crate::GraphModule::from_graph(self)
    }

    /// Lowers a HIR module to MIR and extracts its graph.
    ///
    /// # Errors
    ///
    /// Returns the `LowerError` produced by `HirModule::lower_to_mir`.
    pub fn from_hir(hir: crate::hir::HirModule) -> Result<Self, crate::hir::LowerError> {
        hir.lower_to_mir().map(|m| m.into_graph())
    }

    /// Converts this graph into a `crate::MirModule` (via `MirModule::from_graph`).
    pub fn to_mir(self) -> crate::MirModule {
        crate::MirModule::from_graph(self)
    }

    /// Extracts the graph from a LIR module (via `LirModule::into_graph`).
    pub fn from_lir(lir: crate::LirModule) -> Self {
        lir.into_graph()
    }

    /// Returns the string produced by `crate::inspect_graph` for this graph.
    pub fn inspect(&self) -> String {
        crate::inspect_graph(self)
    }

    /// Delegates to `crate::dynamic::has_dynamic_dims`.
    pub fn has_dynamic_dims(&self) -> bool {
        crate::dynamic::has_dynamic_dims(self)
    }

    /// Delegates to `crate::dynamic::collect_dynamic_symbols`.
    pub fn dynamic_symbols(&self) -> Vec<u32> {
        crate::dynamic::collect_dynamic_symbols(self)
    }

    /// Returns a new graph built by `crate::dynamic::bind_graph`. Presumably
    /// it substitutes the bound values for dynamic dimensions; `self` is
    /// left untouched.
    pub fn bind(&self, bindings: &crate::DimBinding) -> Self {
        crate::dynamic::bind_graph(self, bindings)
    }

    /// Convenience forwarder to `GraphModule::inspect`.
    pub fn inspect_module(module: &crate::GraphModule) -> String {
        module.inspect()
    }
}
impl std::fmt::Display for Graph {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "graph @{} {{", self.name)?;
for node in &self.nodes {
write!(f, " {} = {}", node.id, node.op)?;
if !node.inputs.is_empty() {
write!(f, "(")?;
for (i, inp) in node.inputs.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{inp}")?;
}
write!(f, ")")?;
}
writeln!(f, " : {}", node.shape)?;
}
if !self.outputs.is_empty() {
write!(f, " return ")?;
for (i, o) in self.outputs.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{o}")?;
}
writeln!(f)?;
}
writeln!(f, "}}")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DType,
        op::{Activation, BinaryOp},
    };

    /// Builds `gelu(x @ w + b)` and checks node count, use counts, users and
    /// the textual rendering produced by `Display for Graph`.
    #[test]
    fn build_simple_graph() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let w = g.param("weight", Shape::new(&[384, 1536], DType::F32));
        let b = g.param("bias", Shape::new(&[1536], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[4, 15, 1536], DType::F32));
        let add = g.binary(BinaryOp::Add, mm, b, Shape::new(&[4, 15, 1536], DType::F32));
        let out = g.activation(
            Activation::Gelu,
            add,
            Shape::new(&[4, 15, 1536], DType::F32),
        );
        g.set_outputs(vec![out]);

        assert_eq!(g.len(), 6);
        assert_eq!(g.use_count(mm), 1);
        assert_eq!(g.use_count(x), 1);
        assert_eq!(g.users(mm), vec![add]);
        // Outputs are not node inputs, so the final node has no users.
        assert_eq!(g.use_count(out), 0);

        let printed = format!("{g}");
        assert!(printed.contains("matmul(%0, %1)"));
        assert!(printed.contains("Gelu(%4)"));
        assert!(printed.contains("return %5"));
    }

    /// Builds a simplified BERT layer (QKV projection + FFN) where the input
    /// feeds two branches, and pins the exact structure.
    #[test]
    fn bert_layer_graph() {
        let mut g = Graph::new("bert_layer");
        let f = DType::F32;
        let h = 384;
        let int = 1536;
        let x = g.input("hidden", Shape::new(&[4, 15, h], f));
        let qkv_w = g.param("qkv.weight", Shape::new(&[h, 3 * h], f));
        let qkv_b = g.param("qkv.bias", Shape::new(&[3 * h], f));
        let qkv = g.matmul(x, qkv_w, Shape::new(&[4, 15, 3 * h], f));
        let _qkv = g.binary(BinaryOp::Add, qkv, qkv_b, Shape::new(&[4, 15, 3 * h], f));
        let int_w = g.param("ffn.weight", Shape::new(&[h, int], f));
        let int_b = g.param("ffn.bias", Shape::new(&[int], f));
        let ffn_mm = g.matmul(x, int_w, Shape::new(&[4, 15, int], f));
        let ffn = g.binary(BinaryOp::Add, ffn_mm, int_b, Shape::new(&[4, 15, int], f));
        let ffn = g.activation(Activation::Gelu, ffn, Shape::new(&[4, 15, int], f));
        let out_w = g.param("ffn_out.weight", Shape::new(&[int, h], f));
        let ffn_out = g.matmul(ffn, out_w, Shape::new(&[4, 15, h], f));
        g.set_outputs(vec![ffn_out]);

        // 1 input + 5 params + 3 matmuls + 2 adds + 1 gelu.
        assert_eq!(g.len(), 12);
        // The hidden state feeds both the QKV and the FFN projection.
        assert_eq!(g.users(x), vec![qkv, ffn_mm]);
        assert_eq!(g.use_count(ffn_out), 0);
        assert!(format!("{g}").contains("return %11"));
    }
}