use std::cmp::{max, min};
use std::fmt::{Debug, Formatter};
use std::iter::zip;

use image::{ImageBuffer, Rgb};
use itertools::Itertools;
use ndarray::{ArcArray, Axis, Ix4};
use palette::{LinSrgb, Srgb};

use crate::cpu::ExecutionInfo;
use crate::dtype::{DTensor, Tensor};
use crate::graph::{Graph, Operation, Value};
use crate::shape::Size;

pub type Image = ImageBuffer<Rgb<u8>, Vec<u8>>;
type Tensor4 = ArcArray<f32, Ix4>;

const VERTICAL_PADDING: usize = 5;
const HORIZONTAL_PADDING: usize = 5;

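/// A single tensor to render. If `normalize` is set, values are recentered and
/// rescaled based on the tensor's mean and standard deviation before being mapped
/// to pixels; otherwise they are clamped to `0..=1` as-is.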
#[derive(Debug)]
pub struct VisTensor {
    pub normalize: bool,
    pub tensor: Tensor<f32>,
}

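/// A [`VisTensor`] scheduled for rendering, tagged with the [`Value`] it was
/// computed from and whether it is the original tensor or a post-processed one.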
#[derive(Debug)]
pub struct RenderTensor {
    value: Value,
    original: bool,
    vis_tensor: VisTensor,
}

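/// Render the intermediate activations recorded in `execution` as a list of images,
/// one per batch element, capped at `max_images`.
///
/// Each shown value is drawn as a horizontal strip of its channels, and the strips
/// are stacked vertically. The green and blue pixel channels encode the activation
/// itself; if `show_variance` is set, the red channel instead encodes how much that
/// element varies across the batch. `post_process_value` can emit extra
/// [`VisTensor`]s to render for any value.
///
/// ```ignore
/// // hypothetical sketch: `graph` and `execution` are assumed to come from an
/// // earlier CPU execution that kept its intermediate tensors
/// let images = visualize_graph_activations(&graph, &execution, |_, _| vec![], Some(4), false, false);
/// for (i, image) in images.iter().enumerate() {
///     image.save(format!("activations_{}.png", i)).unwrap();
/// }
/// ```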
pub fn visualize_graph_activations(
    graph: &Graph,
    execution: &ExecutionInfo,
    post_process_value: impl Fn(Value, &DTensor) -> Vec<VisTensor>,
    max_images: Option<usize>,
    show_variance: bool,
    print_details: bool,
) -> Vec<Image> {
    let batch_size = execution.batch_size;
    // `max_images` is an upper bound, so combine it with the batch size using `min`:
    // rendering more images than there are batch elements would index out of bounds
    let image_count = max_images.map_or(batch_size, |max_images| min(max_images, batch_size));

    if image_count == 0 {
        return vec![];
    }

    let mut total_width = HORIZONTAL_PADDING;
    let mut total_height = VERTICAL_PADDING;

    let mut to_render = vec![];

    // collect the tensors to render
    for value in execution.values.values() {
        let info = &graph[value.value];

        if !should_show_value(graph, value.value) {
            continue;
        }

        // only tensors shaped (batch, fixed dims...) can be rendered per batch element
        let is_intermediate_shape = info.shape.rank() > 0
            && info.shape[0] == Size::BATCH
            && info.shape.dims[1..].iter().all(|d| d.try_unwrap_fixed().is_some());
        if !is_intermediate_shape {
            println!("Skipping value with shape {:?}", info.shape);
            continue;
        }

        let is_input = matches!(&info.operation, Operation::Input { .. });
        let data = value
            .tensor
            .as_ref()
            .expect("Intermediate values should have been kept for visualization");

        if let DTensor::F32(data) = data {
            let vis_tensor = VisTensor {
                // inputs are assumed to already be in [0, 1], everything else is normalized
                normalize: !is_input,
                tensor: data.to_shared(),
            };
            to_render.push(RenderTensor {
                value: value.value,
                original: true,
                vis_tensor,
            });
        }

        for extra_vis_tensor in post_process_value(value.value, data) {
            to_render.push(RenderTensor {
                value: value.value,
                original: false,
                vis_tensor: extra_vis_tensor,
            });
        }
    }

    // compute the layout: each tensor gets a horizontal strip of its channels
    let mut all_details = vec![];
    for render_tensor in to_render {
        let RenderTensor {
            value,
            original,
            vis_tensor,
        } = render_tensor;
        let VisTensor {
            normalize,
            tensor: data,
        } = vis_tensor;
        let size = data.len();

        // bring every tensor into (batch, channels, height, width) form
        let data: Tensor4 = match data.ndim() {
            1 => data.reshape((batch_size, 1, 1, 1)),
            2 => data.reshape((batch_size, 1, 1, size / batch_size)),
            3 => data.insert_axis(Axis(1)).into_dimensionality().unwrap(),
            4 => data.into_dimensionality().unwrap(),
            _ => {
                println!("Skipping value with (picked) shape {:?}", data.dim());
                continue;
            }
        };

        // show (b, c, 1, 1) tensors as a single horizontal strip instead of c tiny ones
        let data = if matches!(data.dim(), (_, _, 1, 1)) {
            data.reshape((batch_size, 1, 1, data.dim().1))
        } else {
            data
        };

        let (_, channels, height, width) = data.dim();

        let view_width = channels * width + (channels - 1) * HORIZONTAL_PADDING;
        let view_height = height;

        // stack the strips vertically, separated by padding
        if total_height != VERTICAL_PADDING {
            total_height += VERTICAL_PADDING;
        }
        let start_y = total_height;
        total_height += view_height;

        total_width = max(total_width, HORIZONTAL_PADDING + view_width);

        let details = Details {
            value,
            original,
            start_y,
            normalize,
            data,
        };
        all_details.push(details);
    }

    total_width += HORIZONTAL_PADDING;
    total_height += VERTICAL_PADDING;

    // convert the linear background color to 8-bit sRGB pixel components
    let background = Srgb::from(LinSrgb::new(0.01, 0.01, 0.01)).into_format::<u8>();
    let background = Rgb([background.red, background.green, background.blue]);

    let mut images = (0..image_count)
        .map(|_| ImageBuffer::from_pixel(total_width as u32, total_height as u32, background))
        .collect_vec();

    // draw each tensor into every image
    for details in all_details.iter() {
        if print_details {
            println!("{:?} {:?}", details, graph[details.value]);
        }

        let data = &details.data;
        let (_, channels, height, width) = data.dim();

        if data.iter().any(|x| !x.is_finite()) {
            eprintln!("Warning: encountered non-finite value in {:?}", details);
        }

        // normalize using the mean and standard deviation of the whole tensor
        let mean = data.mean().unwrap();
        let std = data.std(1.0);
        let data_norm = (data - mean) / std;

        // per-element standard deviation across the batch, itself normalized below
        let std_ele = data_norm.std_axis(Axis(0), 1.0);
        let std_ele_mean = std_ele.mean().unwrap();
        let std_ele_std = std_ele.std(1.0);

        for (image_i, image) in images.iter_mut().enumerate() {
            for c in 0..channels {
                for w in 0..width {
                    let x = HORIZONTAL_PADDING + c * (HORIZONTAL_PADDING + width) + w;
                    for h in 0..height {
                        // flip vertically so h=0 ends up at the bottom of the strip
                        let y = details.start_y + (height - 1 - h);

                        // how unusual this element's batch variance is, mapped to [0, 1]
                        let s = (std_ele[(c, h, w)] - std_ele_mean) / std_ele_std;
                        let s_norm = ((s + 1.0) / 2.0).clamp(0.0, 1.0);

                        let gb = if details.normalize {
                            let f = data_norm[(image_i, c, h, w)];
                            ((f + 1.0) / 2.0).clamp(0.0, 1.0)
                        } else {
                            data[(image_i, c, h, w)].clamp(0.0, 1.0)
                        };
                        let r = if show_variance { s_norm } else { gb };

                        let color = Srgb::from(LinSrgb::new(r, gb, gb)).into_format::<u8>();
                        let p = Rgb([color.red, color.green, color.blue]);
                        image.put_pixel(x as u32, y as u32, p);
                    }
                }
            }
        }
    }

    images
}

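/// Decide whether `value` is interesting enough to render. Graph inputs and outputs
/// are always shown; effectively-constant values are never shown, and neither are
/// values with a user that does not meaningfully transform them, eg. a
/// shape-preserving view or a binary operation that broadcasts them.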
fn should_show_value(graph: &Graph, value: Value) -> bool {
    if graph.inputs().contains(&value) || graph.outputs().contains(&value) {
        return true;
    }

    if is_effectively_constant(graph, value) {
        return false;
    }

    let has_dummy_user = graph.values().any(|other| {
        let other_operation = &graph[other].operation;

        if other_operation.inputs().contains(&value) {
            match other_operation {
                // inputs and constants have no operands, so they can never use `value`
                Operation::Input { .. } | Operation::Constant { .. } => unreachable!(),
                // a view that keeps every axis the same is a dummy user
                &Operation::View { input } => {
                    zip(&graph[input].shape.dims, &graph[other].shape.dims).all(|(l, r)| l == r)
                }
                Operation::Broadcast { .. }
                | Operation::Permute { .. }
                | Operation::Slice { .. }
                | Operation::Flip { .. }
                | Operation::Gather { .. }
                | Operation::Concat { .. }
                | Operation::Conv { .. }
                | Operation::MatMul { .. }
                | Operation::Softmax { .. }
                | Operation::Layernorm { .. }
                | Operation::Reduce { .. }
                | Operation::Unary { .. } => false,
                // a binary operation with mismatched operand shapes (ie. one that
                // broadcasts `value`) also counts as a dummy user
                &Operation::Binary { left, right, op: _ } => graph[left].shape != graph[right].shape,
            }
        } else {
            false
        }
    });

    !has_dummy_user
}

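/// A value is effectively constant if it does not (transitively) depend on any
/// graph input, ie. it can be computed from constants alone.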
fn is_effectively_constant(graph: &Graph, value: Value) -> bool {
    let operation = &graph[value].operation;
    match operation {
        Operation::Input { .. } => false,
        Operation::Constant { .. } => true,
        Operation::View { .. }
        | Operation::Broadcast { .. }
        | Operation::Permute { .. }
        | Operation::Slice { .. }
        | Operation::Flip { .. }
        | Operation::Gather { .. }
        | Operation::Concat { .. }
        | Operation::Conv { .. }
        | Operation::MatMul { .. }
        | Operation::Unary { .. }
        | Operation::Binary { .. }
        | Operation::Softmax { .. }
        | Operation::Layernorm { .. }
        | Operation::Reduce { .. } => operation.inputs().iter().all(|&v| is_effectively_constant(graph, v)),
    }
}

impl VisTensor {
    /// Render `tensor` as-is, clamping values to `0..=1`.
    pub fn abs(tensor: Tensor<f32>) -> VisTensor {
        VisTensor {
            normalize: false,
            tensor,
        }
    }

    /// Render `tensor` normalized by its mean and standard deviation.
    pub fn norm(tensor: Tensor<f32>) -> VisTensor {
        VisTensor {
            normalize: true,
            tensor,
        }
    }
}

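/// Layout and data for a single rendered tensor: `start_y` is the vertical offset
/// of its strip within the output images, `data` is the tensor reshaped to
/// `(batch, channels, height, width)`.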
struct Details {
    value: Value,
    original: bool,
    start_y: usize,

    normalize: bool,
    data: Tensor4,
}

// manual Debug impl so the (potentially huge) tensor is printed as just its shape
impl Debug for Details {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Details")
            .field("value", &self.value)
            .field("original", &self.original)
            .field("start_y", &self.start_y)
            .field("shape", &self.data.dim())
            .finish()
    }
}