1use crate::display_params::DisplayParams;
2use crate::model::Model;
3use nu_ansi_term::{Color, Style};
4use box_drawing::heavy::*;
5use std::fmt;
6use std::fmt::Write;
7use tract_core::internal::*;
8
9#[derive(Clone)]
10pub struct Wire {
11 pub outlet: OutletId,
12 pub color: Option<Style>,
13 pub should_change_color: bool,
14 pub successors: Vec<InletId>,
15}
16
17impl fmt::Debug for Wire {
18 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
19 let s = format!("{:?} {:?}", self.outlet, self.successors);
20 if let Some(c) = self.color {
21 write!(fmt, "{}", c.paint(s))
22 } else {
23 write!(fmt, "{s}")
24 }
25 }
26}
27
28#[derive(Clone, Default)]
29pub struct DrawingState {
30 pub current_color: Style,
31 pub latest_node_color: Style,
32 pub wires: Vec<Wire>,
33}
34
35impl DrawingState {
36 fn current_color(&self) -> Style {
37 self.current_color
38 }
39
40 fn next_color(&mut self) -> Style {
41 let colors = &[
42 Color::Red.normal(),
43 Color::Green.normal(),
44 Color::Yellow.normal(),
45 Color::Blue.normal(),
46 Color::Purple.normal(),
47 Color::Cyan.normal(),
48 Color::White.normal(),
49 Color::Red.bold(),
50 Color::Green.bold(),
51 Color::Yellow.bold(),
52 Color::Blue.bold(),
53 Color::Purple.bold(),
54 Color::Cyan.bold(),
55 Color::White.bold(),
56 ];
57 let color = colors
58 .iter()
59 .min_by_key(|&c| self.wires.iter().filter(|w| w.color == Some(*c)).count())
60 .unwrap();
61 self.current_color = *color;
62 *color
63 }
64
65 fn inputs_to_draw(&self, model: &dyn Model, node: usize) -> Vec<OutletId> {
66 model.node_inputs(node).to_vec()
67 }
68
69 fn passthrough_count(&self, node: usize) -> usize {
70 self.wires.iter().filter(|w| w.successors.iter().any(|i| i.node != node)).count()
71 }
72
73 pub fn draw_node_vprefix(
74 &mut self,
75 model: &dyn Model,
76 node: usize,
77 _opts: &DisplayParams,
78 ) -> TractResult<Vec<String>> {
79 let mut lines = vec![String::new()];
80 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
81 macro_rules! ln {
82 () => {
83 lines.push(String::new())
84 };
85 }
86 let passthrough_count = self.passthrough_count(node);
87 for (ix, &input) in model.node_inputs(node).iter().enumerate().rev() {
94 let wire = self.wires.iter().position(|o| o.outlet == input).unwrap();
95 let wanted = passthrough_count + ix;
96 if wire != wanted {
97 let little = wire.min(wanted);
98 let big = wire.max(wanted);
99 let moving = self.wires[little].clone();
100 let must_clone = moving.successors.iter().any(|i| i.node != node);
101 let offset = self
102 .wires
103 .iter()
104 .skip(little + 1)
105 .take(big - little)
106 .filter(|w| w.color.is_some())
107 .count()
108 + must_clone as usize;
109 if moving.color.is_some() && offset != 0 {
111 let color = moving.color.unwrap();
112 for w in &self.wires[0..little] {
113 if let Some(c) = w.color {
114 p!("{}", c.paint(VERTICAL));
115 }
116 }
117 p!("{}", color.paint(if must_clone { VERTICAL_RIGHT } else { UP_RIGHT }));
119 for _ in 0..offset - 1 {
120 p!("{}", color.paint(HORIZONTAL));
121 }
122 p!("{}", color.paint(DOWN_LEFT));
123 }
124 while self.wires.len() <= big {
125 self.wires.push(Wire { successors: vec![], ..self.wires[little] });
126 }
127 if must_clone {
128 self.wires[little].successors.retain(|&i| i != InletId::new(node, ix));
129 self.wires[big] = Wire {
130 successors: vec![InletId::new(node, ix)],
131 should_change_color: true,
132 ..self.wires[little]
133 };
134 } else {
135 for i in little..big {
136 self.wires.swap(i, i + 1);
137 }
138 }
139 if moving.color.is_some() {
140 if big < self.wires.len() {
141 for w in &self.wires[big + 1..] {
142 if let Some(c) = w.color {
143 p!("{}", c.paint(VERTICAL));
144 } else {
145 p!(" ");
146 }
147 }
148 }
149 ln!();
150 }
151 }
152 }
153 while lines.last().map(|s| s.trim()) == Some("") {
154 lines.pop();
155 }
156 Ok(lines)
157 }
158
159 pub fn draw_node_body(
160 &mut self,
161 model: &dyn Model,
162 node: usize,
163 opts: &DisplayParams,
164 ) -> TractResult<Vec<String>> {
165 let mut lines = vec![String::new()];
166 macro_rules! p { ($($args: expr),*) => { write!(lines.last_mut().unwrap(), $($args),*)?;} }
167 macro_rules! ln {
168 () => {
169 lines.push(String::new())
170 };
171 }
172 let inputs = self.inputs_to_draw(model, node);
173 let passthrough_count = self.passthrough_count(node);
174 let display = opts.konst || !model.node_const(node);
175 if display {
176 for wire in &self.wires[0..passthrough_count] {
177 if let Some(color) = wire.color {
178 p!("{}", color.paint(VERTICAL));
179 }
180 }
181 }
182 let node_output_count = model.node_output_count(node);
183 if display {
184 self.latest_node_color = if !inputs.is_empty() {
185 let wire0 = &self.wires[passthrough_count];
186 if wire0.color.is_some() && !wire0.should_change_color {
187 wire0.color.unwrap()
188 } else {
189 self.next_color()
190 }
191 } else {
192 self.next_color()
193 };
194 match (inputs.len(), node_output_count) {
195 (0, 1) => {
196 p!("{}", self.latest_node_color.paint(DOWN_RIGHT));
197 }
198 (1, 0) => {
199 p!("{}", self.latest_node_color.paint("╹"));
200 }
201 (u, d) => {
202 p!("{}", self.latest_node_color.paint(VERTICAL_RIGHT));
203 for _ in 1..u.min(d) {
204 p!("{}", self.latest_node_color.paint(VERTICAL_HORIZONTAL));
205 }
206 for _ in u..d {
207 p!("{}", self.latest_node_color.paint(DOWN_HORIZONTAL));
208 }
209 for _ in d..u {
210 p!("{}", self.latest_node_color.paint(UP_HORIZONTAL));
211 }
212 }
213 }
214 ln!();
215 }
216 while lines.last().map(|s| s.trim()) == Some("") {
217 lines.pop();
218 }
219 Ok(lines)
220 }
221
222 pub fn draw_node_vfiller(&self, model: &dyn Model, node: usize) -> TractResult<String> {
223 let mut s = String::new();
224 for wire in &self.wires {
225 if let Some(color) = wire.color {
226 write!(&mut s, "{}", color.paint(VERTICAL))?;
227 }
228 }
229 for _ in self.wires.len()..model.node_output_count(node) {
230 write!(&mut s, " ")?;
231 }
232 Ok(s)
233 }
234
235 pub fn draw_node_vsuffix(
236 &mut self,
237 model: &dyn Model,
238 node: usize,
239 opts: &DisplayParams,
240 ) -> TractResult<Vec<String>> {
241 let mut lines = vec![];
242 let passthrough_count = self.passthrough_count(node);
243 let node_output_count = model.node_output_count(node);
244 let node_color = self
245 .wires
246 .get(passthrough_count)
247 .map(|w| w.color)
248 .unwrap_or_else(|| Some(self.current_color()));
249 self.wires.truncate(passthrough_count);
250 for slot in 0..node_output_count {
251 let outlet = OutletId::new(node, slot);
252 let successors = model.outlet_successors(outlet).to_vec();
253 let color = if !opts.konst && model.node_const(node) {
254 None
255 } else if slot == 0 && node_color.is_some() {
256 Some(self.latest_node_color)
257 } else {
258 Some(self.next_color())
259 };
260 self.wires.push(Wire { outlet, color, successors, should_change_color: false });
261 }
262 let wires_before = self.wires.clone();
263 self.wires.retain(|w| !w.successors.is_empty());
264 for (wanted_at, w) in self.wires.iter().enumerate() {
265 let is_at = wires_before.iter().position(|w2| w.outlet == w2.outlet).unwrap();
266 if wanted_at < is_at {
267 let mut s = String::new();
268 for w in 0..wanted_at {
269 if let Some(color) = self.wires[w].color {
270 write!(&mut s, "{}", color.paint(VERTICAL))?;
271 }
272 }
273 if let Some(color) = self.wires[wanted_at].color {
274 write!(&mut s, "{}", color.paint(DOWN_RIGHT))?;
275 for w in is_at + 1..wanted_at {
276 if self.wires[w].color.is_some() {
277 write!(&mut s, "{}", color.paint(HORIZONTAL))?;
278 }
279 }
280 write!(&mut s, "{}", color.paint(UP_LEFT))?;
281 for w in is_at..self.wires.len() {
282 if let Some(color) = self.wires[w].color {
283 write!(&mut s, "{}", color.paint(VERTICAL))?;
284 }
285 }
286 }
287 lines.push(s);
288 }
289 }
290 while lines.last().map(|s| s.trim()) == Some("") {
292 lines.pop();
293 }
294 Ok(lines)
295 }
296}