Skip to main content

tract_libcli/
draw.rs

1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use box_drawing::heavy::*;
4use nu_ansi_term::{Color, Style};
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11    pub outlet: OutletId,
12    pub color: Option<Style>,
13    pub should_change_color: bool,
14    pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19        let s = format!("{:?} {:?}", self.outlet, self.successors);
20        if let Some(c) = self.color { write!(fmt, "{}", c.paint(s)) } else { write!(fmt, "{s}") }
21    }
22}
23
24#[derive(Clone, Default)]
25pub struct DrawingState {
26    pub current_color: Style,
27    pub latest_node_color: Style,
28    pub wires: Vec<Wire>,
29}
30
31impl DrawingState {
32    fn current_color(&self) -> Style {
33        self.current_color
34    }
35
36    fn next_color(&mut self) -> Style {
37        let colors = &[
38            Color::Red.normal(),
39            Color::Green.normal(),
40            Color::Yellow.normal(),
41            Color::Blue.normal(),
42            Color::Purple.normal(),
43            Color::Cyan.normal(),
44            Color::White.normal(),
45            Color::Red.bold(),
46            Color::Green.bold(),
47            Color::Yellow.bold(),
48            Color::Blue.bold(),
49            Color::Purple.bold(),
50            Color::Cyan.bold(),
51            Color::White.bold(),
52        ];
53        let color = colors
54            .iter()
55            .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
56            .unwrap();
57        self.current_color = *color;
58        *color
59    }
60
61    fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
62        model.node_inputs(node).to_vec()
63    }
64
65    fn passthrough_count(&self, node: usize) -> usize {
66        self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
67    }
68
69    pub fn draw_node_vprefix(
70        &mut self,
71        model: &dyn Model,
72        node: usize,
73        _opts: &DisplayParams,
74    ) -> TractResult<Vec<String>> {
75        let mut lines = vec![String::new()];
76        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
77        macro_rules! ln {
78            () => {
79                lines.push(String::new())
80            };
81        }
82        let passthrough_count = self.passthrough_count(node);
83        /*
84        println!("\n{}", model.node_format(node));
85        for (ix, w) in self.wires.iter().enumerate() {
86            println!(" * {} {:?}", ix, w);
87        }
88        */
89        for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
90            let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
91            let wanted = passthrough_count + ix;
92            if wire != wanted {
93                let little = wire.min(wanted);
94                let big = wire.max(wanted);
95                let moving = self.wires[little].clone();
96                let must_clone = moving.successors.iter().any(|i| i.node != node);
97                let offset = self
98                    .wires
99                    .iter()
100                    .skip(little + 1)
101                    .take(big - little)
102                    .filter(|w| w.color.is_some())
103                    .count()
104                    + must_clone as usize;
105                // println!("{}->{} (offset: {})", little, big, offset);
106                #[allow(clippy::unnecessary_unwrap)]
107                if moving.color.is_some() && offset != 0 {
108                    let color = moving.color.unwrap();
109                    for w in &self.wires[0..little] {
110                        if let Some(c) = w.color {
111                            p!("{}", c.paint(VERTICAL));
112                        }
113                    }
114                    // println!("offset: {}", offset);
115                    p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
116                    for _ in 0..offset - 1 {
117                        p!("{}", color.paint(HORIZONTAL));
118                    }
119                    p!("{}", color.paint(DOWN_LEFT));
120                }
121                while self.wires.len() <= big {
122                    self.wires.push(Wire { successors: vec![], ..self.wires[little] });
123                }
124                if must_clone {
125                    self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
126                    self.wires[big] = Wire {
127                        successors: vec![InletId::new(node, ix)],
128                        should_change_color: true,
129                        ..self.wires[little]
130                    };
131                } else {
132                    for i in little..big {
133                        self.wires.swap(i, i + 1);
134                    }
135                }
136                if moving.color.is_some() {
137                    if big < self.wires.len() {
138                        for w in &self.wires[big + 1..] {
139                            if let Some(c) = w.color {
140                                p!("{}", c.paint(VERTICAL));
141                            } else {
142                                p!(" ");
143                            }
144                        }
145                    }
146                    ln!();
147                }
148            }
149        }
150        while lines.last().map(|s| s.trim()) == Some("") {
151            lines.pop();
152        }
153        Ok(lines)
154    }
155
156    pub fn draw_node_body(
157        &mut self,
158        model: &dyn Model,
159        node: usize,
160        opts: &DisplayParams,
161    ) -> TractResult<Vec<String>> {
162        let mut lines = vec![String::new()];
163        macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
164        macro_rules! ln {
165            () => {
166                lines.push(String::new())
167            };
168        }
169        let inputs = self.inputs_to_draw(model, node);
170        let passthrough_count = self.passthrough_count(node);
171        let display = opts.konst || !model.node_const(node);
172        if display {
173            for wire in &self.wires[0..passthrough_count] {
174                if let Some(color) = wire.color {
175                    p!("{}", color.paint(VERTICAL));
176                }
177            }
178        }
179        let node_output_count = model.node_output_count(node);
180        if display {
181            self.latest_node_color = if !inputs.is_empty() {
182                let wire0 = &self.wires[passthrough_count];
183                #[allow(clippy::unnecessary_unwrap)]
184                if wire0.color.is_some() && !wire0.should_change_color {
185                    wire0.color.unwrap()
186                } else {
187                    self.next_color()
188                }
189            } else {
190                self.next_color()
191            };
192            match (inputs.len(), node_output_count) {
193                (0, 1) => {
194                    p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
195                }
196                (1, 0) => {
197                    p!("{}", self.latest_node_color.paint("╹"));
198                }
199                (u, d) => {
200                    p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
201                    for _ in 1..u.min(d) {
202                        p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
203                    }
204                    for _ in u..d {
205                        p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
206                    }
207                    for _ in d..u {
208                        p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
209                    }
210                }
211            }
212            ln!();
213        }
214        while lines.last().map(|s| s.trim()) == Some("") {
215            lines.pop();
216        }
217        Ok(lines)
218    }
219
220    pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
221        let mut s = String::new();
222        for wire in &self.wires {
223            if let Some(color) = wire.color {
224                write!(&mut s, "{}", color.paint(VERTICAL))?;
225            }
226        }
227        for _ in self.wires.len()..model.node_output_count(node) {
228            write!(&mut s, " ")?;
229        }
230        Ok(s)
231    }
232
233    pub fn draw_node_vsuffix(
234        &mut self,
235        model: &dyn Model,
236        node: usize,
237        opts: &DisplayParams,
238    ) -> TractResult<Vec<String>> {
239        let mut lines = vec![];
240        let passthrough_count = self.passthrough_count(node);
241        let node_output_count = model.node_output_count(node);
242        let node_color = self
243            .wires
244            .get(passthrough_count)
245            .map(|w| w.color)
246            .unwrap_or_else(|| Some(self.current_color()));
247        self.wires.truncate(passthrough_count);
248        for slot in 0..node_output_count {
249            let outlet = OutletId::new(node, slot);
250            let successors = model.outlet_successors(outlet).to_vec();
251            let color = if !opts.konst && model.node_const(node) {
252                None
253            } else if slot == 0 && node_color.is_some() {
254                Some(self.latest_node_color)
255            } else {
256                Some(self.next_color())
257            };
258            self.wires.push(Wire { outlet, color, successors, should_change_color: false });
259        }
260        let wires_before = self.wires.clone();
261        self.wires.retain(|w| !w.successors.is_empty());
262        for (wanted_at, w) in self.wires.iter().enumerate() {
263            let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
264            if wanted_at < is_at {
265                let mut s = String::new();
266                for w in 0..wanted_at {
267                    if let Some(color) = self.wires[w].color {
268                        write!(&mut s, "{}", color.paint(VERTICAL))?;
269                    }
270                }
271                if let Some(color) = self.wires[wanted_at].color {
272                    write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
273                    for w in is_at + 1..wanted_at {
274                        if self.wires[w].color.is_some() {
275                            write!(&mut s, "{}", color.paint(HORIZONTAL))?;
276                        }
277                    }
278                    write!(&mut s, "{}", color.paint(UP_LEFT))?;
279                    for w in is_at..self.wires.len() {
280                        if let Some(color) = self.wires[w].color {
281                            write!(&mut s, "{}", color.paint(VERTICAL))?;
282                        }
283                    }
284                }
285                lines.push(s);
286            }
287        }
288        // println!("{:?}", self.wires);
289        while lines.last().map(|s| s.trim()) == Some("") {
290            lines.pop();
291        }
292        Ok(lines)
293    }
294}