1use crate::dtype::DType;
2use crate::node::Node;
3use crate::shape::Shape;
4use crate::tensor::Id;
5use std::collections::BTreeSet;
6
/// Extension trait that wraps any iterator in a [`SizedIter`] carrying a
/// caller-declared length, so it can be used where an [`ExactSizeIterator`]
/// is required.
pub trait SizedIterator: Iterator + Sized {
    /// Wraps `self`, declaring that it will yield exactly `len` more items.
    ///
    /// The declared length is trusted, not verified. If the inner iterator ends
    /// early, the reported length drops to 0. If it yields more items than
    /// declared, the reported length stays at 0 (it never underflows).
    fn make_sized(self, len: usize) -> SizedIter<Self::Item, Self> {
        SizedIter { iter: self, len }
    }
}

impl<IT: Iterator> SizedIterator for IT {}

/// Iterator adapter that reports a caller-declared number of remaining items.
///
/// Created by [`SizedIterator::make_sized`].
pub struct SizedIter<T, IT: Iterator<Item = T>> {
    iter: IT,
    /// Number of items still expected from `iter`; decremented on every yield.
    len: usize,
}

impl<T, IT: Iterator<Item = T>> Iterator for SizedIter<T, IT> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self.iter.next() {
            Some(item) => {
                // Saturating so an inner iterator that overruns its declared
                // length cannot cause an arithmetic overflow panic.
                self.len = self.len.saturating_sub(1);
                Some(item)
            }
            None => {
                // Exhausted: nothing remains, whatever was declared.
                self.len = 0;
                None
            }
        }
    }

    /// Exact hint required by the `ExactSizeIterator` contract; also lets
    /// `collect` and `extend` pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }
}

impl<T, IT: Iterator<Item = T>> ExactSizeIterator for SizedIter<T, IT> {
    /// Number of items remaining, not the length originally declared.
    fn len(&self) -> usize {
        self.len
    }
}
35
36pub fn get_shape(nodes: &[Node], mut x: Id) -> &Shape {
38 loop {
39 let node = &nodes[x.i()];
40 match node {
41 Node::Leaf(shape, ..)
42 | Node::Uniform(shape, ..)
43 | Node::Reshape(_, shape)
44 | Node::Expand(_, shape)
45 | Node::Permute(.., shape)
46 | Node::Pad(.., shape)
47 | Node::Sum(.., shape)
48 | Node::Max(.., shape) => return shape,
49 _ => x = node.parameters().next().unwrap(),
50 }
51 }
52}
53
54pub fn get_dtype(nodes: &[Node], mut x: Id) -> DType {
56 loop {
57 let node = &nodes[x.i()];
58 match node {
59 Node::Leaf(_, dtype) | Node::Uniform(_, dtype) | Node::Cast(_, dtype) => return *dtype,
60 _ => x = node.parameters().next().unwrap(),
61 }
62 }
63}
64
65pub fn plot_graph_dot(ids: &BTreeSet<Id>, nodes: &[Node], rcs: &[u32]) -> alloc::string::String {
67 use alloc::{format, string::String};
69 use core::fmt::Write;
70 let mut user_rc = rcs.to_vec();
71 for (i, node) in nodes.iter().enumerate() {
72 if rcs[i] > 0 {
74 for param in node.parameters() {
75 user_rc[param.i()] -= 1;
76 }
77 }
78 }
79 let mut res = String::from("strict digraph {\n ordering=in\n rank=source\n");
81 let mut add_node = |i: usize, text: &str, shape: &str| {
82 let fillcolor = if user_rc[i] > 0 { "lightblue" } else { "grey" };
83 write!(
88 res,
89 " {i}[label=\"{} x {}NL{}NL{}\", shape={}, fillcolor=\"{}\", style=filled]",
90 crate::tensor::id(i),
91 rcs[i],
92 text,
93 get_shape(nodes, crate::tensor::id(i)),
94 shape,
95 fillcolor
96 )
97 .unwrap();
98 writeln!(res).unwrap();
99 };
100 let mut edges = String::new();
101 for id in ids {
102 let id = id.i();
103 let node = &nodes[id];
104 match node {
105 Node::Leaf(sh, dtype) => add_node(id, &format!("Leaf({sh}, {dtype})"), "box"),
106 Node::Uniform(sh, dtype) => add_node(id, &format!("Uniform({sh}, {dtype})"), "box"),
107 Node::Add(x, y) => add_node(id, &format!("Add({x}, {y})"), "oval"),
108 Node::Sub(x, y) => add_node(id, &format!("Sub({x}, {y})"), "oval"),
109 Node::Mul(x, y) => add_node(id, &format!("Mul({x}, {y})"), "oval"),
110 Node::Div(x, y) => add_node(id, &format!("Div({x}, {y})"), "oval"),
111 Node::Cmplt(x, y) => add_node(id, &format!("Cmplt({x}, {y})"), "oval"),
112 Node::Where(x, y, z) => add_node(id, &format!("Cmplt({x}, {y}, {z})"), "oval"),
113 Node::Pow(x, y) => add_node(id, &format!("Pow({x}, {y})"), "oval"),
114 Node::Detach(x) => add_node(id, &format!("Detach({x})"), "oval"),
115 Node::Neg(x) => add_node(id, &format!("Neg({x})"), "oval"),
116 Node::Exp(x) => add_node(id, &format!("Exp({x})"), "oval"),
117 Node::ReLU(x) => add_node(id, &format!("ReLU({x})"), "oval"),
118 Node::Ln(x) => add_node(id, &format!("Ln({x})"), "oval"),
119 Node::Sin(x) => add_node(id, &format!("Sin({x})"), "oval"),
120 Node::Cos(x) => add_node(id, &format!("Cos({x})"), "oval"),
121 Node::Sqrt(x) => add_node(id, &format!("Sqrt({x})"), "oval"),
122 Node::Tanh(x) => add_node(id, &format!("Tanh({x})"), "oval"),
123 Node::Expand(x, ..) => add_node(id, &format!("Expand({x})"), "oval"),
124 Node::Pad(x, padding, ..) => add_node(id, &format!("Pad({x}, {padding:?})"), "oval"),
125 Node::Cast(x, dtype) => add_node(id, &format!("CastI32({x}, {dtype})"), "oval"),
126 Node::Reshape(x, ..) => add_node(id, &format!("Reshape({x})"), "oval"),
127 Node::Permute(x, axes, ..) => add_node(id, &format!("Permute({x}, {axes:?})"), "oval"),
128 Node::Sum(x, axes, ..) => add_node(id, &format!("Sum({x}, {axes:?})"), "oval"),
129 Node::Max(x, axes, ..) => add_node(id, &format!("Max({x}, {axes:?})"), "oval"),
130 }
131 for param in node.parameters() {
132 writeln!(edges, " {} -> {id}", param.i()).unwrap();
133 }
134 }
135 res = res.replace("NL", "\n");
136 write!(res, "{edges}}}").unwrap();
137 res
138}