1use crate::ir::hlo::{HloNode, HloOp};
2
3#[derive(Clone, Debug)]
5pub struct IRGraph {
6 pub nodes: Vec<HloNode>,
7}
8
9impl IRGraph {
10 pub fn new() -> Self { Self { nodes: vec![] } }
11
12 pub fn push(&mut self, node: HloNode) -> usize {
14 let id = node.id;
15 self.nodes.push(node);
16 id
17 }
18
19 pub fn binary_op(op: HloOp, left_shape: Vec<usize>, right_shape: Vec<usize>) -> Self {
21 let shape = if left_shape.len() >= right_shape.len() { left_shape.clone() } else { right_shape.clone() };
23 let mut g = IRGraph::new();
24 g.nodes.push(HloNode::new(0, op.clone(), vec![], left_shape));
26 g.nodes.push(HloNode::new(1, op.clone(), vec![], right_shape));
27 g.nodes.push(HloNode::new(2, op, vec![0,1], shape));
28 g
29 }
30}