1use crate::{Op, Shape};
23
24use crate::provenance::NodeOrigin;
25
/// Identifier of a node inside a [`Graph`].
///
/// Wraps the node's position in the graph's node list and is displayed as `%<index>`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u32);

impl std::fmt::Display for NodeId {
    /// Writes the id as a `%` sigil followed by the raw index, e.g. `%3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NodeId(index) = *self;
        f.write_str("%")?;
        write!(f, "{index}")
    }
}
36
/// A single operation in a [`Graph`], together with its operands and result shape.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct Node {
    /// This node's id; `Graph::push_ext` assigns it from the node's position in the graph.
    pub id: NodeId,
    /// The operation this node performs.
    pub op: Op,
    /// Ids of the nodes this node consumes, in operand order.
    pub inputs: Vec<NodeId>,
    /// Shape recorded for this node's result (returned by `Graph::shape`).
    pub shape: Shape,
    /// Optional name attached to the node.
    pub name: Option<String>,
    /// Optional provenance record (`crate::provenance::NodeOrigin`); `Graph::push` leaves it `None`.
    pub origin: Option<NodeOrigin>,
}
53
54#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
75#[derive(Clone, Debug)]
76pub struct Graph {
77 pub name: String,
78 nodes: Vec<Node>,
79 pub outputs: Vec<NodeId>,
81}
82
83impl PartialEq for Graph {
87 fn eq(&self, other: &Self) -> bool {
88 self.name == other.name
89 && self.nodes.len() == other.nodes.len()
90 && self.outputs == other.outputs
91 }
92}
93
94impl Graph {
95 pub fn new(name: impl Into<String>) -> Self {
96 Self {
97 name: name.into(),
98 nodes: Vec::new(),
99 outputs: Vec::new(),
100 }
101 }
102
103 pub fn len(&self) -> usize {
105 self.nodes.len()
106 }
107 pub fn is_empty(&self) -> bool {
108 self.nodes.is_empty()
109 }
110
111 pub fn node(&self, id: NodeId) -> &Node {
113 &self.nodes[id.0 as usize]
114 }
115
116 pub fn nodes(&self) -> &[Node] {
118 &self.nodes
119 }
120
121 pub fn shape(&self, id: NodeId) -> &Shape {
123 &self.nodes[id.0 as usize].shape
124 }
125
126 pub fn set_outputs(&mut self, outputs: Vec<NodeId>) {
128 self.outputs = outputs;
129 }
130
131 pub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>) {
137 self.nodes[id.0 as usize].inputs = inputs;
138 }
139
140 pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
141 &mut self.nodes[id.0 as usize]
142 }
143
144 pub fn nodes_mut(&mut self) -> &mut [Node] {
145 &mut self.nodes
146 }
147
148 pub fn append_node(
154 &mut self,
155 op: Op,
156 inputs: Vec<NodeId>,
157 shape: Shape,
158 name: Option<String>,
159 ) -> NodeId {
160 self.push(op, inputs, shape, name)
161 }
162
163 pub(crate) fn push(
164 &mut self,
165 op: Op,
166 inputs: Vec<NodeId>,
167 shape: Shape,
168 name: Option<String>,
169 ) -> NodeId {
170 self.push_ext(op, inputs, shape, name, None)
171 }
172
173 pub(crate) fn push_ext(
174 &mut self,
175 op: Op,
176 inputs: Vec<NodeId>,
177 shape: Shape,
178 name: Option<String>,
179 origin: Option<NodeOrigin>,
180 ) -> NodeId {
181 let id = NodeId(self.nodes.len() as u32);
182 self.nodes.push(Node {
183 id,
184 op,
185 inputs,
186 shape,
187 name,
188 origin,
189 });
190 id
191 }
192
193 pub fn users(&self, id: NodeId) -> Vec<NodeId> {
200 self.nodes
201 .iter()
202 .filter(|n| n.inputs.contains(&id))
203 .map(|n| n.id)
204 .collect()
205 }
206
207 pub fn use_count(&self, id: NodeId) -> usize {
209 self.nodes.iter().filter(|n| n.inputs.contains(&id)).count()
210 }
211
212 pub fn topo_order(&self) -> impl Iterator<Item = NodeId> + '_ {
214 (0..self.nodes.len()).map(|i| NodeId(i as u32))
215 }
216
217 pub fn reverse_topo(&self) -> impl Iterator<Item = NodeId> + '_ {
219 (0..self.nodes.len()).rev().map(|i| NodeId(i as u32))
220 }
221
222 pub fn define(
229 name: impl Into<String>,
230 build: impl FnOnce(&mut crate::hir::HirModule) -> crate::hir::HirNodeId,
231 ) -> crate::GraphModule {
232 crate::GraphModule::define(name, build)
233 }
234
235 pub fn hir(name: impl Into<String>) -> crate::GraphModule {
237 crate::GraphModule::hir(name)
238 }
239
240 pub fn module(self) -> crate::GraphModule {
242 crate::GraphModule::from_graph(self)
243 }
244
245 pub fn from_hir(hir: crate::hir::HirModule) -> Result<Self, crate::hir::LowerError> {
247 hir.lower_to_mir().map(|m| m.into_graph())
248 }
249
250 pub fn to_mir(self) -> crate::MirModule {
252 crate::MirModule::from_graph(self)
253 }
254
255 pub fn from_lir(lir: crate::LirModule) -> Self {
257 lir.into_graph()
258 }
259
260 pub fn inspect(&self) -> String {
262 crate::inspect_graph(self)
263 }
264
265 pub fn has_dynamic_dims(&self) -> bool {
267 crate::dynamic::has_dynamic_dims(self)
268 }
269
270 pub fn dynamic_symbols(&self) -> Vec<u32> {
272 crate::dynamic::collect_dynamic_symbols(self)
273 }
274
275 pub fn bind(&self, bindings: &crate::DimBinding) -> Self {
277 crate::dynamic::bind_graph(self, bindings)
278 }
279
280 pub fn inspect_module(module: &crate::GraphModule) -> String {
282 module.inspect()
283 }
284}
285
286impl std::fmt::Display for Graph {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 writeln!(f, "graph @{} {{", self.name)?;
290 for node in &self.nodes {
291 write!(f, " {} = {}", node.id, node.op)?;
292 if !node.inputs.is_empty() {
293 write!(f, "(")?;
294 for (i, inp) in node.inputs.iter().enumerate() {
295 if i > 0 {
296 write!(f, ", ")?;
297 }
298 write!(f, "{inp}")?;
299 }
300 write!(f, ")")?;
301 }
302 writeln!(f, " : {}", node.shape)?;
303 }
304 if !self.outputs.is_empty() {
305 write!(f, " return ")?;
306 for (i, o) in self.outputs.iter().enumerate() {
307 if i > 0 {
308 write!(f, ", ")?;
309 }
310 write!(f, "{o}")?;
311 }
312 writeln!(f)?;
313 }
314 writeln!(f, "}}")
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DType,
        op::{Activation, BinaryOp},
    };

    /// Builds `gelu(x @ w + b)` and checks node count, use counts, ordering and printing.
    #[test]
    fn build_simple_graph() {
        let mut g = Graph::new("test");

        let x = g.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let w = g.param("weight", Shape::new(&[384, 1536], DType::F32));
        let b = g.param("bias", Shape::new(&[1536], DType::F32));

        let mm = g.matmul(x, w, Shape::new(&[4, 15, 1536], DType::F32));
        let add = g.binary(BinaryOp::Add, mm, b, Shape::new(&[4, 15, 1536], DType::F32));
        let out = g.activation(
            Activation::Gelu,
            add,
            Shape::new(&[4, 15, 1536], DType::F32),
        );

        g.set_outputs(vec![out]);

        assert_eq!(g.len(), 6);
        assert_eq!(g.use_count(mm), 1);
        assert_eq!(g.use_count(x), 1);
        assert_eq!(g.users(mm), vec![add]);
        // The final activation is consumed only as a graph output, not by any node.
        assert_eq!(g.use_count(out), 0);

        let order: Vec<NodeId> = g.topo_order().collect();
        assert_eq!(order, (0..6u32).map(NodeId).collect::<Vec<_>>());
        assert_eq!(g.reverse_topo().next(), Some(out));

        let printed = format!("{g}");
        assert!(printed.contains("matmul(%0, %1)"));
        assert!(printed.contains("Gelu(%4)"));
        assert!(printed.contains("return %5"));
    }

    /// Builds a BERT-style layer fragment (QKV projection + feed-forward block).
    #[test]
    fn bert_layer_graph() {
        let mut g = Graph::new("bert_layer");
        let f = DType::F32;
        let h = 384;
        let inter = 1536;

        let x = g.input("hidden", Shape::new(&[4, 15, h], f));

        // QKV projection.
        let qkv_w = g.param("qkv.weight", Shape::new(&[h, 3 * h], f));
        let qkv_b = g.param("qkv.bias", Shape::new(&[3 * h], f));
        let qkv_mm = g.matmul(x, qkv_w, Shape::new(&[4, 15, 3 * h], f));
        let _qkv = g.binary(BinaryOp::Add, qkv_mm, qkv_b, Shape::new(&[4, 15, 3 * h], f));

        // Feed-forward block.
        let ffn_w = g.param("ffn.weight", Shape::new(&[h, inter], f));
        let ffn_b = g.param("ffn.bias", Shape::new(&[inter], f));
        let ffn_mm = g.matmul(x, ffn_w, Shape::new(&[4, 15, inter], f));
        let ffn_add = g.binary(BinaryOp::Add, ffn_mm, ffn_b, Shape::new(&[4, 15, inter], f));
        let ffn_act = g.activation(Activation::Gelu, ffn_add, Shape::new(&[4, 15, inter], f));

        let out_w = g.param("ffn_out.weight", Shape::new(&[inter, h], f));
        let ffn_out = g.matmul(ffn_act, out_w, Shape::new(&[4, 15, h], f));

        g.set_outputs(vec![ffn_out]);

        // 1 input + 5 params + 6 compute nodes.
        assert_eq!(g.len(), 12);
        // The hidden state feeds both the QKV and the FFN matmuls.
        assert_eq!(g.users(x), vec![qkv_mm, ffn_mm]);
        assert_eq!(g.use_count(x), 2);
        assert!(format!("{g}").contains("return %11"));
    }
}
388}