1use cuneus::compute::*;
2use cuneus::prelude::*;
3
// Uniform parameter block for the CNN compute shader. It is registered with
// `with_custom_uniforms::<CNNParams>()` and uploaded with `set_custom_params`.
// NOTE(review): presumably copied to the GPU byte-for-byte, so field order and
// types must match the corresponding struct in shaders/cnn.wgsl — confirm.
cuneus::uniform_params! {
    struct CNNParams {
        // Brush size; the UI slider limits it to 0.001..=0.015 (initial value 0.007).
        brush_size: f32,
        // Canvas side length, initialised to 28.0 (the canvas_data buffer holds 28 * 28 elements).
        input_resolution: f32,
        // Set to 1 by `render` when the "Clear Canvas" button is clicked, 0 otherwise.
        clear_canvas: i32,
        // Initialised to 0 and never changed in this file; its meaning is defined by the shader.
        show_debug: i32,
        // Feature-map count of the first conv layer (16, the same factor that sizes conv1_data).
        feature_maps_1: f32,
        // Feature-map count of the second conv layer (32, the same factor that sizes conv2_data).
        feature_maps_2: f32,
        // Number of output classes (47; fc_data is sized for 47 elements).
        num_classes: f32,
        // NOTE(review): presumably the input normalisation mean / std (0.175 / 0.33) — confirm in the shader.
        normalization_mean: f32,
        normalization_std: f32,
        // Initialised to 0 and never changed in this file; its meaning is defined by the shader.
        show_frequencies: i32,
        // 12: the same side length that sizes conv1_data (12 * 12 * 16 elements).
        // NOTE(review): presumably the pooled output side length of conv layer 1 — confirm.
        conv1_pool_size: f32,
        // 4: the same side length that sizes conv2_data (4 * 4 * 32 elements).
        // NOTE(review): presumably the pooled output side length of conv layer 2 — confirm.
        conv2_pool_size: f32,
        // Padding: with these eight fields the block is 20 four-byte fields (80 bytes).
        // NOTE(review): presumably there to satisfy WGSL uniform alignment — confirm against the shader.
        _padding1: f32,
        _padding2: f32,
        _padding3: f32,
        _padding4: f32,
        _padding5: f32,
        _padding6: f32,
        _pad_m1: f32,
        _pad_m2: f32,
    }
}
28
29struct CNNDigitRecognizer {
30 base: RenderKit,
31 compute_shader: ComputeShader,
32 current_params: CNNParams,
33 first_frame: bool}
34
// No inherent methods: all behaviour is provided by the `ShaderManager` impl.
// NOTE(review): this empty impl block has no effect and can be removed.
impl CNNDigitRecognizer {}
36
37impl ShaderManager for CNNDigitRecognizer {
38 fn init(core: &Core) -> Self {
39 let base = RenderKit::new(core);
40
41 let passes = vec![
43 PassDescription::new("canvas_update", &[]).with_workgroup_size([28, 28, 1]),
44
45 PassDescription::new("conv_layer1", &["canvas_update"])
46 .with_workgroup_size([12, 12, 16]), PassDescription::new("conv_layer2", &["conv_layer1"])
49 .with_workgroup_size([4, 4, 32]), PassDescription::new("fully_connected", &["conv_layer2"])
52 .with_workgroup_size([47, 1, 1]), PassDescription::new("main_image", &["fully_connected"]),
55 ];
56
57 let compute_shader = ComputeShaderBuilder::new()
58 .with_label("CNN Digit Recognizer")
59 .with_multi_pass(&passes)
60 .with_custom_uniforms::<CNNParams>()
61 .with_mouse()
62 .with_fonts()
63 .with_storage_buffer(StorageBufferSpec::new("canvas_data", (28 * 28 * 4) as u64))
64 .with_storage_buffer(StorageBufferSpec::new(
65 "conv1_data",
66 (12 * 12 * 16 * 4) as u64,
67 ))
68 .with_storage_buffer(StorageBufferSpec::new(
69 "conv2_data",
70 (4 * 4 * 32 * 4) as u64
71 ))
72 .with_storage_buffer(StorageBufferSpec::new(
73 "fc_data",
74 (47 * 4) as u64
75 ))
76 .build();
77
78 let compute_shader = cuneus::compute_shader!(core, "shaders/cnn.wgsl", compute_shader);
79
80
81 let current_params = CNNParams {
82 brush_size: 0.007,
83 input_resolution: 28.0,
84 clear_canvas: 0,
85 show_debug: 0,
86 feature_maps_1: 16.0,
87 feature_maps_2: 32.0,
88 num_classes: 47.0,
89 normalization_mean: 0.175,
90 normalization_std: 0.33,
91 show_frequencies: 0,
92 conv1_pool_size: 12.0,
93 conv2_pool_size: 4.0,
94 _padding1: 0.0,
95 _padding2: 0.0,
96 _padding3: 0.0,
97 _padding4: 0.0,
98 _padding5: 0.0,
99 _padding6: 0.0,
100 _pad_m1: 0.0,
101 _pad_m2: 0.0,
102 };
103
104 Self {
105 base,
106 compute_shader,
107 current_params,
108 first_frame: true}
109 }
110
111 fn update(&mut self, _core: &Core) {
112 }
113
114 fn resize(&mut self, core: &Core) {
115 self.compute_shader
116 .resize(core, core.size.width, core.size.height);
117 }
118
119 fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
120 let mut frame = self.base.begin_frame(core)?;
121
122
123 let mut params = self.current_params;
124 let mut changed = self.first_frame; let mut should_start_export = false;
126 let mut export_request = self.base.export_manager.get_ui_request();
127 let mut controls_request = self
128 .base
129 .controls
130 .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
131
132 let full_output = if self.base.key_handler.show_ui {
133 self.base.render_ui(core, |ctx| {
134 RenderKit::apply_default_style(ctx);
135
136 egui::Window::new("CNN chr Recognizer")
137 .collapsible(true)
138 .resizable(true)
139 .default_width(280.0)
140 .show(ctx, |ui| {
141 ui.label("Draw a character in the canvas area and watch the CNN predict it!");
142 ui.separator();
143 ui.label("The CNN will predict the character using pre-trained weights");
144 ui.separator();
145
146 egui::CollapsingHeader::new("Brush")
147 .default_open(true)
148 .show(ui, |ui| {
149 changed |= ui
150 .add(
151 egui::Slider::new(&mut params.brush_size, 0.001..=0.015)
152 .text("Brush Size"),
153 )
154 .changed();
155 if ui.button("Clear Canvas").clicked() {
156 params.clear_canvas = 1;
157 changed = true;
158 } else {
159 params.clear_canvas = 0;
160 }
161 });
162
163 ui.separator();
164 ShaderControls::render_controls_widget(ui, &mut controls_request);
165
166 ui.separator();
167 should_start_export =
168 ExportManager::render_export_ui_widget(ui, &mut export_request);
169 });
170 })
171 } else {
172 self.base.render_ui(core, |_ctx| {})
173 };
174
175 self.compute_shader
177 .update_mouse_uniform(&self.base.mouse_tracker.uniform, &core.queue);
178
179 self.compute_shader.dispatch(&mut frame.encoder, core);
182
183 self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
184
185 self.base.apply_control_request(controls_request.clone());
187
188 self.base.export_manager.apply_ui_request(export_request);
189 if should_start_export {
190 self.base.export_manager.start_export();
191 }
192
193 if changed {
194 self.current_params = params;
195 self.compute_shader.set_custom_params(params, &core.queue);
196 self.first_frame = false;
197 }
198
199 self.base.end_frame(core, frame, full_output);
200
201 Ok(())
202 }
203
204 fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
205 if self.base.default_handle_input(core, event) {
206 return true;
207 }
208 self.base.handle_mouse_input(core, event, false)
209 }
210}
211
212fn main() -> Result<(), Box<dyn std::error::Error>> {
213 let (app, event_loop) = ShaderApp::new("EMNIST", 800, 600);
214
215 app.run(event_loop, CNNDigitRecognizer::init)
216}