Skip to main content

cnn/
cnn.rs

1use cuneus::compute::*;
2use cuneus::prelude::*;
3
4cuneus::uniform_params! {
5    struct CNNParams {
6    brush_size: f32,
7    input_resolution: f32,
8    clear_canvas: i32,
9    show_debug: i32,
10    feature_maps_1: f32,
11    feature_maps_2: f32,
12    num_classes: f32,
13    normalization_mean: f32,
14    normalization_std: f32,
15    show_frequencies: i32,
16    conv1_pool_size: f32,
17    conv2_pool_size: f32,
18    _padding1: f32,
19    _padding2: f32,
20    _padding3: f32,
21    _padding4: f32,
22    _padding5: f32,
23    _padding6: f32,
24    _pad_m1: f32,
25    _pad_m2: f32,
26    }
27}
28
29struct CNNDigitRecognizer {
30    base: RenderKit,
31    compute_shader: ComputeShader,
32    current_params: CNNParams,
33    first_frame: bool}
34
35impl CNNDigitRecognizer {}
36
37impl ShaderManager for CNNDigitRecognizer {
38    fn init(core: &Core) -> Self {
39        let base = RenderKit::new(core);
40
41        // Configure multi-pass CNN with 5 stages: canvas_update -> conv_layer1 -> conv_layer2 -> fully_connected -> main_image
42        let passes = vec![
43            PassDescription::new("canvas_update", &[]).with_workgroup_size([28, 28, 1]),
44            
45            PassDescription::new("conv_layer1", &["canvas_update"])
46                .with_workgroup_size([12, 12, 16]), // 16 Feature Maps
47            
48            PassDescription::new("conv_layer2", &["conv_layer1"])
49                .with_workgroup_size([4, 4, 32]),   // 32 Feature Maps
50            
51            PassDescription::new("fully_connected", &["conv_layer2"])
52                .with_workgroup_size([47, 1, 1]),   // 47 Classes
53            
54            PassDescription::new("main_image", &["fully_connected"]),
55        ];
56
57        let compute_shader = ComputeShaderBuilder::new()
58            .with_label("CNN Digit Recognizer")
59            .with_multi_pass(&passes)
60            .with_custom_uniforms::<CNNParams>()
61            .with_mouse()
62            .with_fonts()
63            .with_storage_buffer(StorageBufferSpec::new("canvas_data", (28 * 28 * 4) as u64))
64            .with_storage_buffer(StorageBufferSpec::new(
65                "conv1_data",
66                (12 * 12 * 16 * 4) as u64,
67            )) 
68            .with_storage_buffer(StorageBufferSpec::new(
69                "conv2_data", 
70                (4 * 4 * 32 * 4) as u64
71            )) 
72            .with_storage_buffer(StorageBufferSpec::new(
73                "fc_data", 
74                (47 * 4) as u64
75            ))
76            .build();
77
78        let compute_shader = cuneus::compute_shader!(core, "shaders/cnn.wgsl", compute_shader);
79
80
81        let current_params = CNNParams {
82            brush_size: 0.007,
83            input_resolution: 28.0,
84            clear_canvas: 0,
85            show_debug: 0,
86            feature_maps_1: 16.0,
87            feature_maps_2: 32.0,
88            num_classes: 47.0,
89            normalization_mean: 0.175,
90            normalization_std: 0.33,
91            show_frequencies: 0,
92            conv1_pool_size: 12.0,
93            conv2_pool_size: 4.0,
94            _padding1: 0.0,
95            _padding2: 0.0,
96            _padding3: 0.0,
97            _padding4: 0.0,
98            _padding5: 0.0,
99            _padding6: 0.0,
100            _pad_m1: 0.0,
101            _pad_m2: 0.0,
102        };
103
104        Self {
105            base,
106            compute_shader,
107            current_params,
108            first_frame: true}
109    }
110
111    fn update(&mut self, _core: &Core) {
112    }
113
114    fn resize(&mut self, core: &Core) {
115        self.compute_shader
116            .resize(core, core.size.width, core.size.height);
117    }
118
119    fn render(&mut self, core: &Core) -> Result<(), cuneus::SurfaceError> {
120        let mut frame = self.base.begin_frame(core)?;
121
122
123        let mut params = self.current_params;
124        let mut changed = self.first_frame; // Update params on first frame
125        let mut should_start_export = false;
126        let mut export_request = self.base.export_manager.get_ui_request();
127        let mut controls_request = self
128            .base
129            .controls
130            .get_ui_request(&self.base.start_time, &core.size, self.base.fps_tracker.fps());
131
132        let full_output = if self.base.key_handler.show_ui {
133            self.base.render_ui(core, |ctx| {
134                RenderKit::apply_default_style(ctx);
135
136                egui::Window::new("CNN chr Recognizer")
137                    .collapsible(true)
138                    .resizable(true)
139                    .default_width(280.0)
140                    .show(ctx, |ui| {
141                        ui.label("Draw a character in the canvas area and watch the CNN predict it!");
142                        ui.separator();
143                        ui.label("The CNN will predict the character using pre-trained weights");
144                        ui.separator();
145
146                        egui::CollapsingHeader::new("Brush")
147                            .default_open(true)
148                            .show(ui, |ui| {
149                                changed |= ui
150                                    .add(
151                                        egui::Slider::new(&mut params.brush_size, 0.001..=0.015)
152                                            .text("Brush Size"),
153                                    )
154                                    .changed();
155                                if ui.button("Clear Canvas").clicked() {
156                                    params.clear_canvas = 1;
157                                    changed = true;
158                                } else {
159                                    params.clear_canvas = 0;
160                                }
161                            });
162
163                        ui.separator();
164                        ShaderControls::render_controls_widget(ui, &mut controls_request);
165
166                        ui.separator();
167                        should_start_export =
168                            ExportManager::render_export_ui_widget(ui, &mut export_request);
169                    });
170            })
171        } else {
172            self.base.render_ui(core, |_ctx| {})
173        };
174
175        // Update mouse uniform for drawing interaction
176        self.compute_shader
177            .update_mouse_uniform(&self.base.mouse_tracker.uniform, &core.queue);
178
179        // Execute CNN pipeline
180        // Note: our backend automatically uses custom workgroup sizes from PassDescription
181        self.compute_shader.dispatch(&mut frame.encoder, core);
182
183        self.base.renderer.render_to_view(&mut frame.encoder, &frame.view, &self.compute_shader.get_output_texture().bind_group);
184
185        // Apply UI changes
186        self.base.apply_control_request(controls_request.clone());
187
188        self.base.export_manager.apply_ui_request(export_request);
189        if should_start_export {
190            self.base.export_manager.start_export();
191        }
192
193        if changed {
194            self.current_params = params;
195            self.compute_shader.set_custom_params(params, &core.queue);
196            self.first_frame = false;
197        }
198
199        self.base.end_frame(core, frame, full_output);
200
201        Ok(())
202    }
203
204    fn handle_input(&mut self, core: &Core, event: &WindowEvent) -> bool {
205        if self.base.default_handle_input(core, event) {
206            return true;
207        }
208        self.base.handle_mouse_input(core, event, false)
209    }
210}
211
212fn main() -> Result<(), Box<dyn std::error::Error>> {
213    let (app, event_loop) = ShaderApp::new("EMNIST", 800, 600);
214
215    app.run(event_loop, CNNDigitRecognizer::init)
216}