1use fontdue::layout::LayoutSettings;
2use minidx_core::Dtype;
3use raqote::*;
4
5mod chart;
6use chart::LineChart;
7
8mod network_traits;
9pub use network_traits::VisualizableNetwork;
10
11mod font;
12pub use font::VisFont;
13
14pub mod anim;
15
/// Convenience re-exports of this crate's main entry points: the `anim`
/// module, `ParamVisOpts` and the `VisualizableNetwork` trait.
pub mod prelude {
    pub use crate::anim;
    pub use crate::ParamVisOpts;
    pub use crate::VisualizableNetwork;
}
21
/// Geometry and label size of a single parameter cell.
#[derive(Debug, Clone)]
pub struct ParamBox {
    /// Width of one cell.
    w: f32,
    /// Height of one cell.
    h: f32,
    /// Font size handed to the font rasterizer for the cell's label.
    font_size: f32,
}

impl Default for ParamBox {
    /// A square 40x40 cell labeled at font size 17.
    fn default() -> Self {
        let side = 40.0;

        ParamBox {
            font_size: 17.0,
            w: side,
            h: side,
        }
    }
}
39
/// Scaling mode for parameter values.
///
/// NOTE(review): no code visible in this file consumes this type; the array
/// painter currently uses a fixed scale of `1.0` — confirm where (or whether)
/// this is wired in.
#[derive(Debug, Clone, Default)]
pub enum ParamScale {
    /// No scaling. This is the default variant.
    #[default]
    None,
    /// Presumably scales relative to the standard deviation of the values,
    /// multiplied by `mul` — TODO confirm against the consumer.
    StdDev {
        mul: f32,
    },
}
49
/// Layout state and styling used while painting parameters.
#[derive(Debug, Clone)]
pub struct ParamVisOpts {
    /// Current drawing cursor `(x, y)`: the top-left corner at which the next
    /// parameters are painted. Advanced by `update_cursor`.
    offset: (f32, f32),
    /// Padding `(x, y)` between modules. The `y` component is added to every
    /// vertical cursor advance in `update_cursor`.
    module_padding: (f32, f32),
    /// Size and label font size of each parameter cell.
    cell: ParamBox,
    /// Font used to rasterize the cell labels.
    font: VisFont,
}
64
65impl Default for ParamVisOpts {
66 fn default() -> Self {
67 let font = VisFont::default_font().unwrap();
68
69 Self {
70 offset: (2.0, 2.0),
71 module_padding: (2.0, 6.0),
72 cell: Default::default(),
73 font,
74 }
75 }
76}
77
78impl ParamVisOpts {
79 pub fn update_cursor(&mut self, offset: (f32, f32)) -> &mut Self {
82 self.offset.0 += offset.0;
83 self.offset.1 += offset.1 + self.module_padding.1;
84
85 self
86 }
87
88 pub fn small() -> Self {
90 Self {
91 cell: ParamBox {
92 w: 20.0,
93 h: 20.0,
94 font_size: 9.0,
95 },
96 ..Default::default()
97 }
98 }
99}
100
/// A drawing surface that can render a parameter container of type `P`.
///
/// Implemented in this file by `DrawTarget` for `()` and for 2-D arrays.
trait PaintParams<P> {
    /// NOTE(review): not referenced by any code visible in this file; every
    /// impl here sets it to `P` itself — confirm the intended use.
    type Concrete: Sized;

    /// Draws `params` onto `self`, positioned and styled according to `opts`.
    fn paint_params(&mut self, params: &P, opts: &mut ParamVisOpts);
    /// Returns the `(width, height)` extent reserved for a `P` under `opts`.
    fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32);
}
108
109impl PaintParams<()> for DrawTarget {
110 type Concrete = ();
111 fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32) {
112 (opts.module_padding.0, opts.module_padding.1)
113 }
114
115 fn paint_params(&mut self, _params: &(), _opts: &mut ParamVisOpts) {}
116}
117
118impl<E: Dtype, const I: usize, const O: usize> PaintParams<[[E; I]; O]> for DrawTarget {
119 type Concrete = [[E; I]; O];
120 fn layout_bounds(&self, opts: &ParamVisOpts) -> (f32, f32) {
121 (opts.cell.w * I as f32, opts.cell.h * O as f32)
122 }
123
124 fn paint_params(&mut self, params: &[[E; I]; O], opts: &mut ParamVisOpts) {
125 let scale = 1.0;
126
127 for (j, params) in params.iter().enumerate() {
128 for (i, v) in params.iter().enumerate() {
129 let tl = (
130 opts.offset.0 + opts.cell.w * i as f32,
131 opts.offset.1 + opts.cell.h * j as f32,
132 );
133
134 let mut pb = PathBuilder::new();
136 pb.move_to(tl.0, tl.1);
137 pb.line_to(tl.0 + opts.cell.w, tl.1);
138 pb.line_to(tl.0 + opts.cell.w, tl.1 + opts.cell.h);
139 pb.line_to(tl.0, tl.1 + opts.cell.h);
140 pb.line_to(tl.0, tl.1);
141 let p = pb.finish();
142
143 let v = scale * v.to_f32().unwrap();
145 self.fill(
146 &p,
147 &Source::Solid(SolidSource::from_unpremultiplied_argb(
148 0xFF,
149 ((-v).tanh().max(0.0) * 130.0) as u8 + 48,
150 (v.tanh().max(0.0) * 120.0) as u8 + 48,
151 48,
152 )),
153 &DrawOptions::new(),
154 );
155 self.stroke(
157 &p,
158 &Source::Solid(SolidSource::from_unpremultiplied_argb(0xFF, 0, 0, 0)),
159 &StrokeStyle {
160 width: 1.0,
161 ..StrokeStyle::default()
162 },
163 &DrawOptions::new(),
164 );
165
166 let v_abs = v.abs();
168 let mut s = if v_abs >= 10.0 {
169 format!("{:.0}", v_abs)
170 } else {
171 format!("{:.1}", v_abs)
172 };
173 s.truncate(3);
174
175 opts.font.raster(
176 &LayoutSettings {
177 x: tl.0,
178 y: tl.1 + 1.0,
179 max_width: Some(opts.cell.w),
180 max_height: Some(opts.cell.h - 2.0),
181 horizontal_align: fontdue::layout::HorizontalAlign::Center,
182 vertical_align: fontdue::layout::VerticalAlign::Middle,
183 ..LayoutSettings::default()
184 },
185 s.as_str(),
186 opts.cell.font_size,
187 (201, 201, 201),
188 self,
189 );
190 }
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: painting a 2x11 grid that spans negative, small and large
    /// values must not panic. The rendered image is not inspected.
    #[test]
    fn test_paint_params() {
        let params = &[
            [
                -10.0, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1,
            ],
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 10.0],
        ];
        let mut dt = DrawTarget::new(460, 200);
        dt.clear(SolidSource::from_unpremultiplied_argb(
            0xff, 0xcf, 0xcf, 0xcf,
        ));

        dt.paint_params(params, &mut ParamVisOpts::small());
    }

    /// Smoke test: visualizing a small dense network with default options
    /// must not panic. The rendered image is not inspected.
    #[test]
    fn test_visualize() {
        use minidx_core::layers as l;
        let network = (
            (
                l::Dense::<f32, 2, 3>::default(),
                l::Bias1d::<f32, 3>::default(),
            ),
            l::Activation::<f32>::default(),
            l::Dense::<f32, 3, 1>::default(),
        );
        let mut dt = DrawTarget::new(460, 500);
        dt.clear(SolidSource::from_unpremultiplied_argb(
            0xff, 0xcf, 0xcf, 0xcf,
        ));

        // `VisualizableNetwork` is already in scope through `use super::*`.
        let mut opts = ParamVisOpts::default();
        network.visualize(&mut dt, &mut opts);
    }
}