minidx_vis/
lib.rs

1use fontdue::layout::LayoutSettings;
2use minidx_core::Dtype;
3use raqote::*;
4
5mod chart;
6use chart::LineChart;
7
8mod network_traits;
9pub use network_traits::VisualizableNetwork;
10
11mod font;
12pub use font::VisFont;
13
14pub mod anim;
15
16pub mod prelude {
17    pub use crate::anim;
18    pub use crate::ParamVisOpts;
19    pub use crate::VisualizableNetwork;
20}
21
/// Size of the grid cell that holds a single parameter, together with the
/// font size of the value label drawn inside that cell.
#[derive(Debug, Clone)]
pub struct ParamBox {
    /// Cell width.
    w: f32,
    /// Cell height.
    h: f32,
    /// Font size of the label rendered in the cell.
    font_size: f32,
}

impl Default for ParamBox {
    /// A square 40 x 40 cell with a 17.0 label font size.
    fn default() -> Self {
        let side = 40.0;
        ParamBox {
            w: side,
            h: side,
            font_size: 17.0,
        }
    }
}
39
/// How to scale the representation of parameters relative to each other.
///
/// Defaults to [`ParamScale::None`].
// NOTE(review): this enum is not consumed anywhere in this file (cell painting
// uses a fixed scale of 1.0) — confirm semantics against its users.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum ParamScale {
    /// No relative scaling (the default).
    #[default]
    None,
    /// Scaling based on the parameters' standard deviation.
    StdDev {
        /// Multiplier for the standard-deviation based scaling.
        mul: f32,
    },
}
55
56/// Options for rendering a set of parameters.
57#[derive(Debug, Clone)]
58pub struct ParamVisOpts {
59    offset: (f32, f32),
60    module_padding: (f32, f32),
61    cell: ParamBox,
62    font: VisFont,
63}
64
65impl Default for ParamVisOpts {
66    fn default() -> Self {
67        let font = VisFont::default_font().unwrap();
68
69        Self {
70            offset: (2.0, 2.0),
71            module_padding: (2.0, 6.0),
72            cell: Default::default(),
73            font,
74        }
75    }
76}
77
78impl ParamVisOpts {
79    /// Returns a new [ParamVisOpts] with the offset updated for laying out
80    /// the next module.
81    pub fn update_cursor(&mut self, offset: (f32, f32)) -> &mut Self {
82        self.offset.0 += offset.0;
83        self.offset.1 += offset.1 + self.module_padding.1;
84
85        self
86    }
87
88    /// A small-cell variant.
89    pub fn small() -> Self {
90        Self {
91            cell: ParamBox {
92                w: 20.0,
93                h: 20.0,
94                font_size: 9.0,
95            },
96            ..Default::default()
97        }
98    }
99}
100
101/// Implements visual rendering of a set of parameters.
102trait PaintParams<P> {
103    type Concrete: Sized;
104
105    fn paint_params(&mut self, params: &P, opts: &mut ParamVisOpts);
106    fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32);
107}
108
109impl PaintParams<()> for DrawTarget {
110    type Concrete = ();
111    fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32) {
112        (opts.module_padding.0, opts.module_padding.1)
113    }
114
115    fn paint_params(&mut self, _params: &(), _opts: &mut ParamVisOpts) {}
116}
117
118impl<E: Dtype, const I: usize, const O: usize> PaintParams<[[E; I]; O]> for DrawTarget {
119    type Concrete = [[E; I]; O];
120    fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32) {
121        (opts.cell.w * I as f32, opts.cell.h * O as f32)
122    }
123
124    fn paint_params(&mut self, params: &[[E; I]; O], opts: &mut ParamVisOpts) {
125        let scale = 1.0;
126
127        for (j, params) in params.iter().enumerate() {
128            for (i, v) in params.iter().enumerate() {
129                let tl = (
130                    opts.offset.0 + opts.cell.w * i as f32,
131                    opts.offset.1 + opts.cell.h * j as f32,
132                );
133
134                // Make box
135                let mut pb = PathBuilder::new();
136                pb.move_to(tl.0, tl.1);
137                pb.line_to(tl.0 + opts.cell.w, tl.1);
138                pb.line_to(tl.0 + opts.cell.w, tl.1 + opts.cell.h);
139                pb.line_to(tl.0, tl.1 + opts.cell.h);
140                pb.line_to(tl.0, tl.1);
141                let p = pb.finish();
142
143                // Paint red => grey => green background based on parameter
144                let v = scale * v.to_f32().unwrap();
145                self.fill(
146                    &p,
147                    &Source::Solid(SolidSource::from_unpremultiplied_argb(
148                        0xFF,
149                        ((-v).tanh().max(0.0) * 130.0) as u8 + 48,
150                        (v.tanh().max(0.0) * 120.0) as u8 + 48,
151                        48,
152                    )),
153                    &DrawOptions::new(),
154                );
155                // Paint grid cell boundary
156                self.stroke(
157                    &p,
158                    &Source::Solid(SolidSource::from_unpremultiplied_argb(0xFF, 0, 0, 0)),
159                    &StrokeStyle {
160                        width: 1.0,
161                        ..StrokeStyle::default()
162                    },
163                    &DrawOptions::new(),
164                );
165
166                // Generate the text
167                let v_abs = v.abs();
168                let mut s = if v_abs >= 10.0 {
169                    format!("{:.0}", v_abs)
170                } else {
171                    format!("{:.1}", v_abs)
172                };
173                s.truncate(3);
174
175                opts.font.raster(
176                    &LayoutSettings {
177                        x: tl.0,
178                        y: tl.1 + 1.0,
179                        max_width: Some(opts.cell.w),
180                        max_height: Some(opts.cell.h - 2.0),
181                        horizontal_align: fontdue::layout::HorizontalAlign::Center,
182                        vertical_align: fontdue::layout::VerticalAlign::Middle,
183                        ..LayoutSettings::default()
184                    },
185                    s.as_str(),
186                    opts.cell.font_size,
187                    (201, 201, 201),
188                    self,
189                );
190            }
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_paint_params() {
        let params = &[
            [
                -10.0, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1,
            ],
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 10.0],
        ];
        let mut dt = DrawTarget::new(460, 200);
        dt.clear(SolidSource::from_unpremultiplied_argb(
            0xff, 0xcf, 0xcf, 0xcf,
        ));
        let blank = dt.get_data().to_vec();

        dt.paint_params(params, &mut ParamVisOpts::small());

        // Painting the grid must actually change the canvas.
        assert!(
            dt.get_data() != blank.as_slice(),
            "paint_params left the canvas untouched"
        );
        // dt.write_png("/tmp/ye.png").expect("write failed");
    }

    #[test]
    fn test_visualize() {
        use minidx_core::layers as l;
        let network = (
            (
                l::Dense::<f32, 2, 3>::default(),
                l::Bias1d::<f32, 3>::default(),
            ),
            l::Activation::<f32>::default(),
            l::Dense::<f32, 3, 1>::default(),
        );
        let mut dt = DrawTarget::new(460, 500);
        dt.clear(SolidSource::from_unpremultiplied_argb(
            0xff, 0xcf, 0xcf, 0xcf,
        ));

        // `VisualizableNetwork` is already in scope through `use super::*`.
        let mut opts = ParamVisOpts::default();
        network.visualize(&mut dt, &mut opts);
        // dt.write_png("/tmp/ye.png").expect("write failed");
    }
}