use rustc_hash::FxHashMap;
/// Opaque handle identifying a node inside a [`Graph`].
///
/// `Graph::add_node` hands these out as the node's position in the graph's
/// node list, so the wrapped value doubles as an index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Unwraps the handle into its raw `usize` index.
    pub fn index(self) -> usize {
        let Self(raw) = self;
        raw
    }
}
/// Element type of a tensor value; [`DataType::F32`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DataType {
    /// 32-bit floating point (the default).
    #[default]
    F32,
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Boolean, stored in one byte per element.
    Bool,
}

impl DataType {
    /// Number of bytes occupied by one element of this type.
    pub fn size_bytes(self) -> usize {
        // Express each type by its bit width, then convert to whole bytes.
        let bits: usize = match self {
            Self::Bool => 8,
            Self::F32 | Self::I32 => 32,
            Self::F64 | Self::I64 => 64,
        };
        bits / 8
    }
}
/// The dimension sizes of a tensor, outermost dimension first.
///
/// An empty dimension list describes a rank-0 (scalar) shape with one element.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Creates a shape by copying `dims`.
    pub fn new(dims: &[usize]) -> Self {
        Self(dims.to_vec())
    }

    /// Returns the dimension sizes, outermost first.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns the number of dimensions (the rank).
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the total number of elements: the product of all dimensions.
    ///
    /// A rank-0 shape yields 1 (empty product); any zero-sized dimension yields 0.
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Returns `true` if `self` and `other` can be broadcast together.
    ///
    /// Shapes are aligned at their trailing dimension. Missing leading
    /// dimensions count as size 1, and every aligned pair must either be
    /// equal or contain a 1.
    pub fn broadcast_compatible(&self, other: &Self) -> bool {
        let max_ndim = self.ndim().max(other.ndim());
        (0..max_ndim).all(|i| {
            Self::broadcast_dim(self.dim_from_end(i), other.dim_from_end(i)).is_some()
        })
    }

    /// Returns the shape produced by broadcasting `self` with `other`, or
    /// `None` if they are not broadcast-compatible (see
    /// [`Shape::broadcast_compatible`]).
    ///
    /// Each output dimension is the non-1 member of its aligned pair. A
    /// zero-sized dimension broadcast against 1 therefore stays 0 (`[0]` with
    /// `[1]` gives `[0]`), which taking the larger of the two sizes would get
    /// wrong.
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        let max_ndim = self.ndim().max(other.ndim());
        // Walk aligned positions from outermost to innermost so the result is
        // produced outermost-first without a final reverse.
        (0..max_ndim)
            .rev()
            .map(|i| Self::broadcast_dim(self.dim_from_end(i), other.dim_from_end(i)))
            .collect::<Option<Vec<_>>>()
            .map(Self)
    }

    /// Size of the `i`-th dimension counted from the innermost one (`i == 0`).
    /// Returns 1 when the shape has fewer than `i + 1` dimensions.
    fn dim_from_end(&self, i: usize) -> usize {
        self.0.iter().rev().nth(i).copied().unwrap_or(1)
    }

    /// Broadcasts one aligned pair of dimension sizes. Equal sizes pass
    /// through, a 1 yields the other size, and anything else is a conflict
    /// (`None`).
    fn broadcast_dim(d1: usize, d2: usize) -> Option<usize> {
        if d1 == d2 || d2 == 1 {
            Some(d1)
        } else if d1 == 1 {
            Some(d2)
        } else {
            None
        }
    }
}

impl From<&[usize]> for Shape {
    /// Copies the slice into a new shape.
    fn from(dims: &[usize]) -> Self {
        Self::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    /// Takes ownership of the vector as the shape's dimensions.
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}
/// The operation a graph node performs.
///
/// Fields of type [`NodeId`] are operands that refer to other nodes. All
/// other fields are attributes of the op itself. [`Op::inputs`] collects the
/// operands.
#[derive(Debug, Clone, PartialEq)]
// Variant and field names are the primary documentation for these ops.
#[allow(missing_docs)]
pub enum Op {
    // Graph boundaries and constants.
    Input { name: String },
    Output { name: String, input: NodeId },
    Constant { value: f64 },
    // Binary element-wise arithmetic.
    Add { lhs: NodeId, rhs: NodeId },
    Sub { lhs: NodeId, rhs: NodeId },
    Mul { lhs: NodeId, rhs: NodeId },
    Div { lhs: NodeId, rhs: NodeId },
    Pow { base: NodeId, exp: NodeId },
    Max { lhs: NodeId, rhs: NodeId },
    Min { lhs: NodeId, rhs: NodeId },
    // Unary element-wise math and activations.
    Neg { input: NodeId },
    Abs { input: NodeId },
    Sqrt { input: NodeId },
    Exp { input: NodeId },
    Log { input: NodeId },
    Sin { input: NodeId },
    Cos { input: NodeId },
    Tanh { input: NodeId },
    Relu { input: NodeId },
    Sigmoid { input: NodeId },
    Gelu { input: NodeId },
    Silu { input: NodeId },
    // Element-wise ops with a scalar attribute.
    AddScalar { input: NodeId, scalar: f64 },
    MulScalar { input: NodeId, scalar: f64 },
    // Reductions.
    Sum { input: NodeId },
    SumAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    Mean { input: NodeId },
    MeanAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    MaxAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    // Shape / layout manipulation.
    Reshape { input: NodeId, shape: Vec<isize> },
    Transpose {
        input: NodeId,
        dim0: usize,
        dim1: usize,
    },
    Squeeze { input: NodeId, dim: i32 },
    Unsqueeze { input: NodeId, dim: i32 },
    Broadcast { input: NodeId, shape: Vec<usize> },
    // Linear algebra.
    MatMul { lhs: NodeId, rhs: NodeId },
    // Comparisons and selection.
    Gt { lhs: NodeId, rhs: NodeId },
    Lt { lhs: NodeId, rhs: NodeId },
    Eq { lhs: NodeId, rhs: NodeId },
    Where {
        condition: NodeId,
        x: NodeId,
        y: NodeId,
    },
    // Conversion and memory layout.
    Cast { input: NodeId, dtype: DataType },
    Contiguous { input: NodeId },
}

impl Op {
    /// Lists the operand node ids of this op, in field declaration order.
    ///
    /// `Input` and `Constant` have no operands and yield an empty vector.
    pub fn inputs(&self) -> Vec<NodeId> {
        // `NodeId` is `Copy`, so matching on `*self` binds operands by value
        // without moving any non-`Copy` attribute (those are skipped with `..`).
        match *self {
            Self::Where { condition, x, y } => vec![condition, x, y],
            Self::Pow { base, exp } => vec![base, exp],
            Self::Add { lhs, rhs }
            | Self::Sub { lhs, rhs }
            | Self::Mul { lhs, rhs }
            | Self::Div { lhs, rhs }
            | Self::Max { lhs, rhs }
            | Self::Min { lhs, rhs }
            | Self::MatMul { lhs, rhs }
            | Self::Gt { lhs, rhs }
            | Self::Lt { lhs, rhs }
            | Self::Eq { lhs, rhs } => vec![lhs, rhs],
            Self::Output { input, .. }
            | Self::Neg { input }
            | Self::Abs { input }
            | Self::Sqrt { input }
            | Self::Exp { input }
            | Self::Log { input }
            | Self::Sin { input }
            | Self::Cos { input }
            | Self::Tanh { input }
            | Self::Relu { input }
            | Self::Sigmoid { input }
            | Self::Gelu { input }
            | Self::Silu { input }
            | Self::AddScalar { input, .. }
            | Self::MulScalar { input, .. }
            | Self::Sum { input }
            | Self::SumAxis { input, .. }
            | Self::Mean { input }
            | Self::MeanAxis { input, .. }
            | Self::MaxAxis { input, .. }
            | Self::Reshape { input, .. }
            | Self::Transpose { input, .. }
            | Self::Squeeze { input, .. }
            | Self::Unsqueeze { input, .. }
            | Self::Broadcast { input, .. }
            | Self::Cast { input, .. }
            | Self::Contiguous { input } => vec![input],
            Self::Input { .. } | Self::Constant { .. } => Vec::new(),
        }
    }

    /// Reports whether this op is classified as element-wise.
    ///
    /// The match is deliberately exhaustive (no wildcard arm), so adding a
    /// variant forces a decision here.
    pub fn is_elementwise(&self) -> bool {
        match self {
            Self::Add { .. }
            | Self::Sub { .. }
            | Self::Mul { .. }
            | Self::Div { .. }
            | Self::Pow { .. }
            | Self::Max { .. }
            | Self::Min { .. }
            | Self::Neg { .. }
            | Self::Abs { .. }
            | Self::Sqrt { .. }
            | Self::Exp { .. }
            | Self::Log { .. }
            | Self::Sin { .. }
            | Self::Cos { .. }
            | Self::Tanh { .. }
            | Self::Relu { .. }
            | Self::Sigmoid { .. }
            | Self::Gelu { .. }
            | Self::Silu { .. }
            | Self::AddScalar { .. }
            | Self::MulScalar { .. }
            | Self::Gt { .. }
            | Self::Lt { .. }
            | Self::Eq { .. }
            | Self::Where { .. } => true,
            Self::Input { .. }
            | Self::Output { .. }
            | Self::Constant { .. }
            | Self::Sum { .. }
            | Self::SumAxis { .. }
            | Self::Mean { .. }
            | Self::MeanAxis { .. }
            | Self::MaxAxis { .. }
            | Self::Reshape { .. }
            | Self::Transpose { .. }
            | Self::Squeeze { .. }
            | Self::Unsqueeze { .. }
            | Self::Broadcast { .. }
            | Self::MatMul { .. }
            | Self::Cast { .. }
            | Self::Contiguous { .. } => false,
        }
    }

    /// Reports whether this op is one of the reductions: full (`Sum`, `Mean`)
    /// or per-axis (`SumAxis`, `MeanAxis`, `MaxAxis`).
    pub fn is_reduction(&self) -> bool {
        matches!(
            self,
            Self::Sum { .. }
                | Self::Mean { .. }
                | Self::SumAxis { .. }
                | Self::MeanAxis { .. }
                | Self::MaxAxis { .. }
        )
    }
}
/// One entry in a [`Graph`]: an operation plus the dtype and shape recorded for it.
#[derive(Debug, Clone)]
pub struct Node {
    /// This node's id. `Graph::add_node` sets it to the node's index in the graph.
    pub id: NodeId,
    /// The operation this node performs; its operands are other nodes' ids.
    pub op: Op,
    /// Element type recorded for this node (presumably that of its result).
    /// It is supplied by the caller of `add_node` and not checked against `op`.
    pub dtype: DataType,
    /// Shape recorded for this node (presumably that of its result).
    /// It is supplied by the caller of `add_node` and not checked against `op`.
    pub shape: Shape,
}
/// A computation graph that stores its nodes in insertion order.
///
/// A node's [`NodeId`] is its position in that order. Named graph inputs and
/// outputs are tracked in separate name → id registries.
#[derive(Debug, Clone)]
pub struct Graph {
    nodes: Vec<Node>,
    inputs: FxHashMap<String, NodeId>,
    outputs: FxHashMap<String, NodeId>,
}

impl Graph {
    /// Creates an empty graph with no nodes and no registered inputs or outputs.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            inputs: FxHashMap::default(),
            outputs: FxHashMap::default(),
        }
    }

    /// Appends a node and returns its id, which is the node's index.
    ///
    /// Nothing is checked here. Use [`Graph::validate`] to verify operand
    /// references afterwards.
    pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
        let id = NodeId(self.nodes.len());
        self.nodes.push(Node {
            id,
            op,
            dtype,
            shape,
        });
        id
    }

    /// Registers `id` as the graph input called `name`.
    ///
    /// Replaces any earlier registration under the same name.
    pub fn register_input(&mut self, name: &str, id: NodeId) {
        self.inputs.insert(name.to_string(), id);
    }

    /// Registers `id` as the graph output called `name`.
    ///
    /// Replaces any earlier registration under the same name.
    pub fn register_output(&mut self, name: &str, id: NodeId) {
        self.outputs.insert(name.to_string(), id);
    }

    /// Returns the node with the given id.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not belong to this graph (index out of range).
    pub fn node(&self, id: NodeId) -> &Node {
        &self.nodes[id.0]
    }

    /// Returns a mutable reference to the node with the given id.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not belong to this graph (index out of range).
    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
        &mut self.nodes[id.0]
    }

    /// Returns all nodes in insertion order.
    pub fn nodes(&self) -> &[Node] {
        &self.nodes
    }

    /// Returns the number of nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the graph has no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the name → id registry of graph inputs.
    pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
        &self.inputs
    }

    /// Returns the name → id registry of graph outputs.
    pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
        &self.outputs
    }

    /// Looks up a registered input by name.
    pub fn input(&self, name: &str) -> Option<NodeId> {
        self.inputs.get(name).copied()
    }

    /// Looks up a registered output by name.
    pub fn output(&self, name: &str) -> Option<NodeId> {
        self.outputs.get(name).copied()
    }

    /// Returns every node id in insertion order.
    ///
    /// This is a valid topological order only if each node's operands precede
    /// it, which is exactly what [`Graph::validate`] checks.
    pub fn topological_order(&self) -> Vec<NodeId> {
        (0..self.nodes.len()).map(NodeId).collect()
    }

    /// Checks the graph's structural invariants.
    ///
    /// # Errors
    ///
    /// Returns a description of the first violation found:
    /// - a node has an operand id outside the node list;
    /// - a node has itself or a later node as an operand (not a DAG in
    ///   insertion order);
    /// - a registered input points outside the node list, or at a node whose
    ///   op is not `Op::Input`;
    /// - a registered output points outside the node list.
    ///
    /// Registries are scanned in hash-map iteration order. When several
    /// entries are invalid, which one is reported is unspecified.
    pub fn validate(&self) -> Result<(), String> {
        let len = self.nodes.len();
        for node in &self.nodes {
            for input_id in node.op.inputs() {
                if input_id.0 >= len {
                    return Err(format!(
                        "Node {:?} references invalid input {:?}",
                        node.id, input_id
                    ));
                }
                if input_id.0 >= node.id.0 {
                    return Err(format!(
                        "Node {:?} references future node {:?} (not DAG)",
                        node.id, input_id
                    ));
                }
            }
        }
        for (name, &id) in &self.inputs {
            // A dangling registry id must be reported, not panic on indexing.
            let node = self
                .nodes
                .get(id.0)
                .ok_or_else(|| format!("Input '{}' references invalid node {:?}", name, id))?;
            if !matches!(node.op, Op::Input { .. }) {
                return Err(format!("Input '{}' points to non-Input node", name));
            }
        }
        for (name, &id) in &self.outputs {
            // Outputs are looked up later via `node()`, which would panic on a
            // dangling id, so reject it here.
            if id.0 >= len {
                return Err(format!(
                    "Output '{}' references invalid node {:?}",
                    name, id
                ));
            }
        }
        Ok(())
    }
}

impl Default for Graph {
    /// Equivalent to [`Graph::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_numel() {
        let shape = Shape::from(vec![2, 3, 4]);
        assert_eq!(shape.ndim(), 3);
        assert_eq!(shape.numel(), 2 * 3 * 4);
    }

    #[test]
    fn test_shape_broadcast() {
        let (lhs, rhs) = (Shape::from(vec![2, 1, 4]), Shape::from(vec![3, 4]));
        assert!(lhs.broadcast_compatible(&rhs));
        let combined = lhs
            .broadcast_shape(&rhs)
            .expect("[2, 1, 4] and [3, 4] should broadcast");
        assert_eq!(combined, Shape::from(vec![2, 3, 4]));
    }

    #[test]
    fn test_graph_creation() {
        // Every node in this small pipeline shares the same f32 [2, 3] layout.
        let dims: [usize; 2] = [2, 3];
        let mut graph = Graph::default();

        let x = graph.add_node(
            Op::Input {
                name: String::from("x"),
            },
            DataType::F32,
            Shape::new(&dims),
        );
        graph.register_input("x", x);

        let activated = graph.add_node(Op::Relu { input: x }, DataType::F32, Shape::new(&dims));

        let y = graph.add_node(
            Op::Output {
                name: String::from("y"),
                input: activated,
            },
            DataType::F32,
            Shape::new(&dims),
        );
        graph.register_output("y", y);

        assert_eq!(graph.len(), 3);
        assert_eq!(graph.validate(), Ok(()));
    }

    #[test]
    fn test_op_inputs() {
        let cases = [
            (
                Op::Add {
                    lhs: NodeId(0),
                    rhs: NodeId(1),
                },
                vec![NodeId(0), NodeId(1)],
            ),
            (Op::Relu { input: NodeId(2) }, vec![NodeId(2)]),
            (
                Op::Input {
                    name: String::from("x"),
                },
                vec![],
            ),
        ];
        for (op, expected) in cases {
            assert_eq!(op.inputs(), expected, "unexpected operands for {op:?}");
        }
    }
}