zyx_core/
utils.rs

1use crate::dtype::DType;
2use crate::node::Node;
3use crate::shape::Shape;
4use crate::tensor::Id;
5use std::collections::BTreeSet;
6
/// Extension trait that attaches a caller-supplied exact length to any iterator.
///
/// Useful for adaptors such as `filter` whose length the compiler cannot know,
/// but which the caller knows for certain.
pub trait SizedIterator: Iterator + Sized {
    /// Manually add exact size to any iterator.
    ///
    /// `len` must be the number of items the iterator will yield; it is trusted
    /// and not checked against the inner iterator.
    #[must_use]
    fn make_sized(self, len: usize) -> SizedIter<Self::Item, Self> {
        SizedIter { iter: self, len }
    }
}

impl<IT: Iterator> SizedIterator for IT {}

/// Iterator wrapper produced by [`SizedIterator::make_sized`].
///
/// Tracks the number of remaining items so that [`ExactSizeIterator::len`] and
/// [`Iterator::size_hint`] report an exact count throughout iteration.
pub struct SizedIter<T, IT: Iterator<Item = T>> {
    iter: IT,
    /// Number of items still expected from `iter`.
    len: usize,
}

impl<T, IT: Iterator<Item = T>> Iterator for SizedIter<T, IT> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        // `ExactSizeIterator::len` must report the *remaining* length, so keep
        // `len` in step with consumption. Saturate so an inner iterator that
        // yields more than promised cannot underflow, and drop to zero once the
        // inner iterator reports exhaustion.
        self.len = if item.is_some() {
            self.len.saturating_sub(1)
        } else {
            0
        };
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `ExactSizeIterator` requires an exact hint; this also lets `collect`
        // allocate the right capacity up front.
        (self.len, Some(self.len))
    }
}

impl<T, IT: Iterator<Item = T>> ExactSizeIterator for SizedIter<T, IT> {
    fn len(&self) -> usize {
        self.len
    }
}
35
36/// Recursive search to get shape of x in nodes
37pub fn get_shape(nodes: &[Node], mut x: Id) -> &Shape {
38    loop {
39        let node = &nodes[x.i()];
40        match node {
41            Node::Leaf(shape, ..)
42            | Node::Uniform(shape, ..)
43            | Node::Reshape(_, shape)
44            | Node::Expand(_, shape)
45            | Node::Permute(.., shape)
46            | Node::Pad(.., shape)
47            | Node::Sum(.., shape)
48            | Node::Max(.., shape) => return shape,
49            _ => x = node.parameters().next().unwrap(),
50        }
51    }
52}
53
54/// Recursive search to get dtype of x in nodes
55pub fn get_dtype(nodes: &[Node], mut x: Id) -> DType {
56    loop {
57        let node = &nodes[x.i()];
58        match node {
59            Node::Leaf(_, dtype) | Node::Uniform(_, dtype) | Node::Cast(_, dtype) => return *dtype,
60            _ => x = node.parameters().next().unwrap(),
61        }
62    }
63}
64
65/// Puts graph of nodes into dot language for visualization
66pub fn plot_graph_dot(ids: &BTreeSet<Id>, nodes: &[Node], rcs: &[u32]) -> alloc::string::String {
67    //let ids = &(0..nodes.len()).map(crate::tensor::id).collect::<BTreeSet<Id>>();
68    use alloc::{format, string::String};
69    use core::fmt::Write;
70    let mut user_rc = rcs.to_vec();
71    for (i, node) in nodes.iter().enumerate() {
72        // not all nodes are alive :)
73        if rcs[i] > 0 {
74            for param in node.parameters() {
75                user_rc[param.i()] -= 1;
76            }
77        }
78    }
79    //std::println!("User {:?}", user_rc);
80    let mut res = String::from("strict digraph {\n  ordering=in\n  rank=source\n");
81    let mut add_node = |i: usize, text: &str, shape: &str| {
82        let fillcolor = if user_rc[i] > 0 { "lightblue" } else { "grey" };
83        /*if let Some(label) = labels.get(&NodeId::new(id)) {
84            write!(res, "  {id}[label=\"{}NL{} x {}NL{}NL{}\", shape={}, fillcolor=\"{}\", style=filled]",
85                label, id, rc[id], text, get_shape(NodeId::new(id)), shape, fillcolor).unwrap();
86        } else {*/
87        write!(
88            res,
89            "  {i}[label=\"{} x {}NL{}NL{}\", shape={}, fillcolor=\"{}\", style=filled]",
90            crate::tensor::id(i),
91            rcs[i],
92            text,
93            get_shape(nodes, crate::tensor::id(i)),
94            shape,
95            fillcolor
96        )
97        .unwrap();
98        writeln!(res).unwrap();
99    };
100    let mut edges = String::new();
101    for id in ids {
102        let id = id.i();
103        let node = &nodes[id];
104        match node {
105            Node::Leaf(sh, dtype) => add_node(id, &format!("Leaf({sh}, {dtype})"), "box"),
106            Node::Uniform(sh, dtype) => add_node(id, &format!("Uniform({sh}, {dtype})"), "box"),
107            Node::Add(x, y) => add_node(id, &format!("Add({x}, {y})"), "oval"),
108            Node::Sub(x, y) => add_node(id, &format!("Sub({x}, {y})"), "oval"),
109            Node::Mul(x, y) => add_node(id, &format!("Mul({x}, {y})"), "oval"),
110            Node::Div(x, y) => add_node(id, &format!("Div({x}, {y})"), "oval"),
111            Node::Cmplt(x, y) => add_node(id, &format!("Cmplt({x}, {y})"), "oval"),
112            Node::Where(x, y, z) => add_node(id, &format!("Cmplt({x}, {y}, {z})"), "oval"),
113            Node::Pow(x, y) => add_node(id, &format!("Pow({x}, {y})"), "oval"),
114            Node::Detach(x) => add_node(id, &format!("Detach({x})"), "oval"),
115            Node::Neg(x) => add_node(id, &format!("Neg({x})"), "oval"),
116            Node::Exp(x) => add_node(id, &format!("Exp({x})"), "oval"),
117            Node::ReLU(x) => add_node(id, &format!("ReLU({x})"), "oval"),
118            Node::Ln(x) => add_node(id, &format!("Ln({x})"), "oval"),
119            Node::Sin(x) => add_node(id, &format!("Sin({x})"), "oval"),
120            Node::Cos(x) => add_node(id, &format!("Cos({x})"), "oval"),
121            Node::Sqrt(x) => add_node(id, &format!("Sqrt({x})"), "oval"),
122            Node::Tanh(x) => add_node(id, &format!("Tanh({x})"), "oval"),
123            Node::Expand(x, ..) => add_node(id, &format!("Expand({x})"), "oval"),
124            Node::Pad(x, padding, ..) => add_node(id, &format!("Pad({x}, {padding:?})"), "oval"),
125            Node::Cast(x, dtype) => add_node(id, &format!("CastI32({x}, {dtype})"), "oval"),
126            Node::Reshape(x, ..) => add_node(id, &format!("Reshape({x})"), "oval"),
127            Node::Permute(x, axes, ..) => add_node(id, &format!("Permute({x}, {axes:?})"), "oval"),
128            Node::Sum(x, axes, ..) => add_node(id, &format!("Sum({x}, {axes:?})"), "oval"),
129            Node::Max(x, axes, ..) => add_node(id, &format!("Max({x}, {axes:?})"), "oval"),
130        }
131        for param in node.parameters() {
132            writeln!(edges, "  {} -> {id}", param.i()).unwrap();
133        }
134    }
135    res = res.replace("NL", "\n");
136    write!(res, "{edges}}}").unwrap();
137    res
138}