kn_graph/visualize.rs

use std::cmp::max;
use std::fmt::{Debug, Formatter};
use std::iter::zip;

use image::{ImageBuffer, Rgb};
use itertools::Itertools;
use ndarray::{ArcArray, Axis, Ix4};
use palette::{LinSrgb, Srgb};

use crate::cpu::ExecutionInfo;
use crate::dtype::{DTensor, Tensor};
use crate::graph::{Graph, Operation, Value};
use crate::shape::Size;

pub type Image = ImageBuffer<Rgb<u8>, Vec<u8>>;
type Tensor4 = ArcArray<f32, Ix4>;

const VERTICAL_PADDING: usize = 5;
const HORIZONTAL_PADDING: usize = 5;

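/// A tensor selected for rendering.
///
/// If `normalize` is set, the values are shifted and scaled using the mean and standard deviation
/// of the entire tensor before being mapped to pixels; otherwise they are only clamped to `0.0..=1.0`.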
#[derive(Debug)]
pub struct VisTensor {
    pub normalize: bool,
    pub tensor: Tensor<f32>,
}

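/// A [`VisTensor`] together with the graph [`Value`] it was derived from.
/// `original` is true for the raw activation itself and false for tensors
/// produced by the `post_process_value` callback.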
#[derive(Debug)]
pub struct RenderTensor {
    value: Value,
    original: bool,
    vis_tensor: VisTensor,
}

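/// Render the intermediate activations recorded in `execution` to a set of images,
/// one per batch element (capped at `max_images`).
///
/// Each selected value is reshaped to `[B, C, H, W]`; its channels are drawn side by side and the
/// different values are stacked vertically, separated by a few pixels of padding.
/// `post_process_value` can return additional [`VisTensor`]s to render for a given value.
/// If `show_variance` is set, the red channel shows the per-element standard deviation across the
/// batch; if `print_details` is set, layout information for every rendered tensor is printed.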
pub fn visualize_graph_activations(
    graph: &Graph,
    execution: &ExecutionInfo,
    post_process_value: impl Fn(Value, &DTensor) -> Vec<VisTensor>,
    max_images: Option<usize>,
    show_variance: bool,
    print_details: bool,
) -> Vec<Image> {
    let batch_size = execution.batch_size;
    // render at most `max_images` images, but never more than there are batch elements
    let image_count = max_images.map_or(batch_size, |max_images| max_images.min(batch_size));

    // prevent divide by zero issues later
    if image_count == 0 {
        return vec![];
    }

    let mut total_width = HORIZONTAL_PADDING;
    let mut total_height = VERTICAL_PADDING;

    let mut to_render = vec![];

    for value in execution.values.values() {
        let info = &graph[value.value];

        if !should_show_value(graph, value.value) {
            continue;
        }

        // check whether this is the typical intermediate shape: [B, fixed*]
        let is_intermediate_shape = info.shape.rank() > 0
            && info.shape[0] == Size::BATCH
            && info.shape.dims[1..].iter().all(|d| d.try_unwrap_fixed().is_some());
        if !is_intermediate_shape {
            println!("Skipping value with shape {:?}", info.shape);
            continue;
        }

        let is_input = matches!(&info.operation, Operation::Input { .. });
        let data = value
            .tensor
            .as_ref()
            .expect("Intermediate values should have been kept for visualization");

        if let DTensor::F32(data) = data {
            let vis_tensor = VisTensor {
                normalize: !is_input,
                tensor: data.to_shared(),
            };
            to_render.push(RenderTensor {
                value: value.value,
                original: true,
                vis_tensor,
            });
        }

        for extra_vis_tensor in post_process_value(value.value, data) {
            to_render.push(RenderTensor {
                value: value.value,
                original: false,
                vis_tensor: extra_vis_tensor,
            });
        }
    }

    let mut all_details = vec![];
    for render_tensor in to_render {
        let RenderTensor {
            value,
            original,
            vis_tensor,
        } = render_tensor;
        let VisTensor {
            normalize,
            tensor: data,
        } = vis_tensor;
        let size = data.len();

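        // reshape the data into a canonical [batch, channels, height, width] layout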
        let data: Tensor4 = match data.ndim() {
            1 => data.reshape((batch_size, 1, 1, 1)),
            2 => data.reshape((batch_size, 1, 1, size / batch_size)),
            3 => data.insert_axis(Axis(1)).into_dimensionality().unwrap(),
            4 => data.into_dimensionality().unwrap(),
            _ => {
                println!("Skipping value with (picked) shape {:?}", data.dim());
                continue;
            }
        };

        let data = if matches!(data.dim(), (_, _, 1, 1)) {
            data.reshape((batch_size, 1, 1, data.dim().1))
        } else {
            data
        };

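        // channels are placed side by side horizontally, rendered tensors are stacked vertically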
        let (_, channels, height, width) = data.dim();

        let view_width = channels * width + (channels - 1) * HORIZONTAL_PADDING;
        let view_height = height;

        if total_height != VERTICAL_PADDING {
            total_height += VERTICAL_PADDING;
        }
        let start_y = total_height;
        total_height += view_height;

        total_width = max(total_width, HORIZONTAL_PADDING + view_width);

        let details = Details {
            value,
            original,
            start_y,
            normalize,
            data,
        };
        all_details.push(details);
    }

    total_width += HORIZONTAL_PADDING;
    total_height += VERTICAL_PADDING;

    let background = Srgb::from(LinSrgb::new(0.01, 0.01, 0.01));
    let background = Rgb([background.red, background.green, background.blue]);

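    // one output image per rendered batch element, starting from a dark background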
    let mut images = (0..image_count)
        .map(|_| ImageBuffer::from_pixel(total_width as u32, total_height as u32, background))
        .collect_vec();

    for details in all_details.iter() {
        if print_details {
            println!("{:?} {:?}", details, graph[details.value]);
        }

        let data = &details.data;
        let (_, channels, height, width) = data.dim();

        if data.iter().any(|x| !x.is_finite()) {
            eprintln!("Warning: encountered non-finite value in {:?}", details);
        }

        // TODO it's still not clear what the best way to normalize/scale/clamp/represent this stuff is
        let mean = data.mean().unwrap();
        let std = data.std(1.0);
        let data_norm = (data - mean) / std;

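        // per-element standard deviation across the batch, used for the optional variance overlay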
        let std_ele = data_norm.std_axis(Axis(0), 1.0);
        let std_ele_mean = std_ele.mean().unwrap();
        let std_ele_std = std_ele.std(1.0);

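        // paint the pixels: green and blue carry the activation value, red optionally carries the
        // normalized per-element standard deviation across the batch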
        for (image_i, image) in images.iter_mut().enumerate() {
            for c in 0..channels {
                for w in 0..width {
                    let x = HORIZONTAL_PADDING + c * (HORIZONTAL_PADDING + width) + w;
                    for h in 0..height {
                        let y = details.start_y + (height - 1 - h);

                        let s = (std_ele[(c, h, w)] - std_ele_mean) / std_ele_std;
                        let s_norm = ((s + 1.0) / 2.0).clamp(0.0, 1.0);

                        let gb = if details.normalize {
                            let f = data_norm[(image_i, c, h, w)];
                            ((f + 1.0) / 2.0).clamp(0.0, 1.0)
                        } else {
                            data[(image_i, c, h, w)].clamp(0.0, 1.0)
                        };
                        let r = if show_variance { s_norm } else { gb };

                        let color = Srgb::from(LinSrgb::new(r, gb, gb));
                        let p = Rgb([color.red, color.green, color.blue]);
                        image.put_pixel(x as u32, y as u32, p);
                    }
                }
            }
        }
    }

    images
}

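/// Heuristic filter deciding whether a value should be rendered.
///
/// Graph inputs and outputs are always shown, effectively-constant values never are, and values
/// that feed a "dummy" user (a view that only appends trailing 1s, or a binary operation between
/// differently shaped operands) are skipped as well.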
fn should_show_value(graph: &Graph, value: Value) -> bool {
    if graph.inputs().contains(&value) || graph.outputs().contains(&value) {
        return true;
    }

    if is_effectively_constant(graph, value) {
        return false;
    }

    let has_dummy_user = graph.values().any(|other| {
        let other_operation = &graph[other].operation;

        // TODO what are we even calculating here? mostly questionable heuristics?
        if other_operation.inputs().contains(&value) {
            match other_operation {
                Operation::Input { .. } | Operation::Constant { .. } => unreachable!(),
                &Operation::View { input } => {
                    // check whether all leading dims match, which implies the only difference is trailing 1s
                    zip(&graph[input].shape.dims, &graph[other].shape.dims).all(|(l, r)| l == r)
                }
                Operation::Broadcast { .. }
                | Operation::Permute { .. }
                | Operation::Slice { .. }
                | Operation::Flip { .. }
                | Operation::Gather { .. }
                | Operation::Concat { .. }
                | Operation::Conv { .. }
                | Operation::MatMul { .. }
                | Operation::Softmax { .. }
                | Operation::Layernorm { .. }
                | Operation::Reduce { .. }
                | Operation::Unary { .. } => false,
                &Operation::Binary { left, right, op: _ } => graph[left].shape != graph[right].shape,
            }
        } else {
            false
        }
    });

    !has_dummy_user
}

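/// Whether `value` is a constant or is computed exclusively from constants, i.e. it does not
/// (transitively) depend on any graph input.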
fn is_effectively_constant(graph: &Graph, value: Value) -> bool {
    let operation = &graph[value].operation;
    match operation {
        Operation::Input { .. } => false,
        Operation::Constant { .. } => true,
        Operation::View { .. }
        | Operation::Broadcast { .. }
        | Operation::Permute { .. }
        | Operation::Slice { .. }
        | Operation::Flip { .. }
        | Operation::Gather { .. }
        | Operation::Concat { .. }
        | Operation::Conv { .. }
        | Operation::MatMul { .. }
        | Operation::Unary { .. }
        | Operation::Binary { .. }
        | Operation::Softmax { .. }
        | Operation::Layernorm { .. }
        | Operation::Reduce { .. } => operation.inputs().iter().all(|&v| is_effectively_constant(graph, v)),
    }
}

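// Convenience constructors: `abs` renders the raw values (only clamped to `0.0..=1.0`),
// while `norm` renders them normalized by the tensor's mean and standard deviation.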
impl VisTensor {
    pub fn abs(tensor: Tensor<f32>) -> VisTensor {
        VisTensor {
            normalize: false,
            tensor,
        }
    }

    pub fn norm(tensor: Tensor<f32>) -> VisTensor {
        VisTensor {
            normalize: true,
            tensor,
        }
    }
}

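/// Layout information for a single rendered tensor: the value it came from, the row in the output
/// images where it starts, and the `[B, C, H, W]` data to draw.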
struct Details {
    value: Value,
    original: bool,
    start_y: usize,

    normalize: bool,
    data: Tensor4,
}

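// Manual Debug impl that prints the tensor shape instead of dumping its full contents.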
impl Debug for Details {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Details")
            .field("value", &self.value)
            .field("original", &self.original)
            .field("start_y", &self.start_y)
            .field("shape", &self.data.dim())
            .finish()
    }
}