Skip to main content

cal_hardware/
projection_window.rs

1use anyhow::Result;
2use std::sync::Arc;
3use winit::window::Window;
4
5pub struct ProjectionWindow {
6    _window: Arc<Window>,
7    surface: wgpu::Surface<'static>,
8    device: wgpu::Device,
9    queue: wgpu::Queue,
10    _config: wgpu::SurfaceConfiguration,
11    _size: winit::dpi::PhysicalSize<u32>,
12
13    // Textures for each projection frame
14    _textures: Vec<wgpu::Texture>,
15    bind_groups: Vec<wgpu::BindGroup>,
16    render_pipeline: wgpu::RenderPipeline,
17}
18
19impl ProjectionWindow {
20    pub async fn new(
21        event_loop: &winit::event_loop::ActiveEventLoop,
22        width: u32,
23        height: u32,
24        monitor_index: Option<usize>,
25    ) -> Result<Self> {
26        let mut window_attributes = Window::default_attributes()
27            .with_title("CAL Projection Window")
28            .with_inner_size(winit::dpi::PhysicalSize::new(width, height));
29
30        if let Some(idx) = monitor_index {
31            let monitor = event_loop
32                .available_monitors()
33                .nth(idx)
34                .ok_or_else(|| anyhow::anyhow!("Monitor index {} not found", idx))?;
35            println!("Selecting monitor: {:?}", monitor.name());
36            window_attributes = window_attributes
37                .with_fullscreen(Some(winit::window::Fullscreen::Borderless(Some(monitor))));
38        }
39
40        let window = Arc::new(event_loop.create_window(window_attributes)?);
41
42        let instance = wgpu::Instance::default();
43        let surface = instance.create_surface(Arc::clone(&window))?;
44        let adapter = instance
45            .request_adapter(&wgpu::RequestAdapterOptions {
46                power_preference: wgpu::PowerPreference::HighPerformance,
47                force_fallback_adapter: false,
48                compatible_surface: Some(&surface),
49            })
50            .await
51            .ok_or_else(|| anyhow::anyhow!("Failed to find a suitable GPU adapter"))?;
52
53        let (device, queue) = adapter
54            .request_device(&wgpu::DeviceDescriptor::default(), None)
55            .await?;
56
57        let size = window.inner_size();
58        let config = surface
59            .get_default_config(&adapter, size.width, size.height)
60            .unwrap();
61        surface.configure(&device, &config);
62
63        // Simple shader for displaying a texture
64        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
65            label: Some("Projection Shader"),
66            source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
67        });
68
69        let texture_bind_group_layout =
70            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
71                entries: &[
72                    wgpu::BindGroupLayoutEntry {
73                        binding: 0,
74                        visibility: wgpu::ShaderStages::FRAGMENT,
75                        ty: wgpu::BindingType::Texture {
76                            multisampled: false,
77                            view_dimension: wgpu::TextureViewDimension::D2,
78                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
79                        },
80                        count: None,
81                    },
82                    wgpu::BindGroupLayoutEntry {
83                        binding: 1,
84                        visibility: wgpu::ShaderStages::FRAGMENT,
85                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::NonFiltering),
86                        count: None,
87                    },
88                ],
89                label: Some("texture_bind_group_layout"),
90            });
91
92        let render_pipeline_layout =
93            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
94                label: Some("Render Pipeline Layout"),
95                bind_group_layouts: &[&texture_bind_group_layout],
96                push_constant_ranges: &[],
97            });
98
99        let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
100            label: Some("Render Pipeline"),
101            layout: Some(&render_pipeline_layout),
102            vertex: wgpu::VertexState {
103                module: &shader,
104                entry_point: Some("vs_main"),
105                buffers: &[],
106                compilation_options: Default::default(),
107            },
108            fragment: Some(wgpu::FragmentState {
109                module: &shader,
110                entry_point: Some("fs_main"),
111                targets: &[Some(wgpu::ColorTargetState {
112                    format: config.format,
113                    blend: Some(wgpu::BlendState::REPLACE),
114                    write_mask: wgpu::ColorWrites::ALL,
115                })],
116                compilation_options: Default::default(),
117            }),
118            primitive: wgpu::PrimitiveState {
119                topology: wgpu::PrimitiveTopology::TriangleList,
120                ..Default::default()
121            },
122            depth_stencil: None,
123            multisample: wgpu::MultisampleState::default(),
124            multiview: None,
125            cache: None,
126        });
127
128        Ok(Self {
129            _window: window,
130            surface,
131            device,
132            queue,
133            _config: config,
134            _size: size,
135            _textures: Vec::new(),
136            bind_groups: Vec::new(),
137            render_pipeline,
138        })
139    }
140
141    pub fn prepare_projections(&mut self, projections: &ndarray::Array3<f32>) {
142        let (nr, n_angles, nz) = projections.dim();
143
144        // Find global max value for normalization across all angles/frames
145        let mut global_max = 0.0f32;
146        for val in projections.iter() {
147            if *val > global_max {
148                global_max = *val;
149            }
150        }
151        println!("Global max projection intensity: {:.4}", global_max);
152
153        let sampler = self.device.create_sampler(&wgpu::SamplerDescriptor {
154            address_mode_u: wgpu::AddressMode::ClampToEdge,
155            address_mode_v: wgpu::AddressMode::ClampToEdge,
156            mag_filter: wgpu::FilterMode::Nearest,
157            min_filter: wgpu::FilterMode::Nearest,
158            ..Default::default()
159        });
160
161        for a_idx in 0..n_angles {
162            let texture_size = wgpu::Extent3d {
163                width: nr as u32,
164                height: nz as u32,
165                depth_or_array_layers: 1,
166            };
167
168            let texture = self.device.create_texture(&wgpu::TextureDescriptor {
169                size: texture_size,
170                mip_level_count: 1,
171                sample_count: 1,
172                dimension: wgpu::TextureDimension::D2,
173                format: wgpu::TextureFormat::R32Float,
174                usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
175                label: Some(&format!("Projection Texture {}", a_idx)),
176                view_formats: &[],
177            });
178
179            let mut data = vec![0.0f32; nr * nz];
180            for z in 0..nz {
181                for r in 0..nr {
182                    let val = projections[[r, a_idx, z]];
183                    data[z * nr + r] = if global_max > 0.0 {
184                        val / global_max
185                    } else {
186                        0.0
187                    };
188                }
189            }
190
191            self.queue.write_texture(
192                wgpu::TexelCopyTextureInfo {
193                    texture: &texture,
194                    mip_level: 0,
195                    origin: wgpu::Origin3d::ZERO,
196                    aspect: wgpu::TextureAspect::All,
197                },
198                bytemuck::cast_slice(&data),
199                wgpu::TexelCopyBufferLayout {
200                    offset: 0,
201                    bytes_per_row: Some(4 * nr as u32),
202                    rows_per_image: Some(nz as u32),
203                },
204                texture_size,
205            );
206
207            let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
208            let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
209                layout: &self.render_pipeline.get_bind_group_layout(0),
210                entries: &[
211                    wgpu::BindGroupEntry {
212                        binding: 0,
213                        resource: wgpu::BindingResource::TextureView(&view),
214                    },
215                    wgpu::BindGroupEntry {
216                        binding: 1,
217                        resource: wgpu::BindingResource::Sampler(&sampler),
218                    },
219                ],
220                label: None,
221            });
222
223            self._textures.push(texture);
224            self.bind_groups.push(bind_group);
225        }
226    }
227
228    pub fn render(&self, frame_idx: usize) -> Result<()> {
229        let output = self.surface.get_current_texture()?;
230        let view = output
231            .texture
232            .create_view(&wgpu::TextureViewDescriptor::default());
233        let mut encoder = self
234            .device
235            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
236                label: Some("Render Encoder"),
237            });
238
239        {
240            let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
241                label: Some("Render Pass"),
242                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
243                    view: &view,
244                    resolve_target: None,
245                    ops: wgpu::Operations {
246                        load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
247                        store: wgpu::StoreOp::Store,
248                    },
249                })],
250                depth_stencil_attachment: None,
251                occlusion_query_set: None,
252                timestamp_writes: None,
253            });
254
255            render_pass.set_pipeline(&self.render_pipeline);
256            render_pass.set_bind_group(0, &self.bind_groups[frame_idx], &[]);
257            render_pass.draw(0..3, 0..1); // Draw a full-screen triangle
258        }
259
260        self.queue.submit(std::iter::once(encoder.finish()));
261        output.present();
262
263        Ok(())
264    }
265
266    pub fn request_redraw(&self) {
267        self._window.request_redraw();
268    }
269}