Skip to main content

threecrate_gpu/
tsdf.rs

1use crate::device::GpuContext;
2use threecrate_core::{PointCloud, ColoredPoint3f, Error, Result};
3use nalgebra::{Matrix4, Point3};
4use bytemuck::{Pod, Zeroable};
5use wgpu::util::DeviceExt;
6
7/// TSDF voxel data for GPU processing
8#[repr(C)]
9#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
10#[repr(align(16))]  // Ensure 16-byte alignment for GPU
11pub struct TsdfVoxel {
12    pub tsdf_value: f32,
13    pub weight: f32,
14    pub color_r: u32,
15    pub color_g: u32,
16    pub color_b: u32,
17    pub _padding1: u32,
18    pub _padding2: u32,
19    pub _padding3: u32,
20}
21
22/// TSDF volume parameters
23#[derive(Debug, Clone)]
24pub struct TsdfVolume {
25    pub voxel_size: f32,
26    pub truncation_distance: f32,
27    pub resolution: [u32; 3], // [width, height, depth]
28    pub origin: Point3<f32>,
29}
30
31/// Represents a TSDF volume stored on the GPU.
32pub struct TsdfVolumeGpu {
33    pub volume: TsdfVolume,
34    pub voxel_buffer: wgpu::Buffer,
35}
36
37/// Camera intrinsic parameters
38#[repr(C)]
39#[derive(Copy, Clone, Pod, Zeroable)]
40#[repr(align(16))]  // Ensure 16-byte alignment for GPU
41pub struct CameraIntrinsics {
42    pub fx: f32,
43    pub fy: f32,
44    pub cx: f32,
45    pub cy: f32,
46    pub width: u32,
47    pub height: u32,
48    pub depth_scale: f32,
49    pub _padding: f32,
50}
51
52/// TSDF integration parameters
53#[repr(C)]
54#[derive(Copy, Clone, Pod, Zeroable)]
55#[repr(align(16))]
56pub struct TsdfParams {
57    pub voxel_size: f32,
58    pub truncation_distance: f32,
59    pub max_weight: f32,
60    pub iso_value: f32,
61    pub resolution: [u32; 3],
62    pub _padding2: u32,
63    pub origin: [f32; 3],
64    pub _padding3: f32,
65}
66
67#[repr(C)]
68#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
69#[repr(align(16))]  // Ensure 16-byte alignment for GPU
70pub struct GpuPoint3f {
71    pub x: f32,
72    pub y: f32,
73    pub z: f32,
74    pub r: u32,
75    pub g: u32,
76    pub b: u32,
77    pub _padding1: u32,
78    pub _padding2: u32,
79}
80
81impl GpuContext {
82    /// Integrate depth image into TSDF volume
83    pub async fn tsdf_integrate(
84        &self,
85        volume: &mut TsdfVolume,
86        depth_image: &[f32],
87        color_image: Option<&[u8]>, // RGB color data
88        camera_pose: &Matrix4<f32>,
89        intrinsics: &CameraIntrinsics,
90    ) -> Result<Vec<TsdfVoxel>> {
91        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
92        
93        // Create buffers
94        let depth_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
95            label: Some("TSDF Depth Buffer"),
96            contents: bytemuck::cast_slice(depth_image),
97            usage: wgpu::BufferUsages::STORAGE,
98        });
99
100        let color_buffer = if let Some(color_data) = color_image {
101            // Convert RGB u8 data to packed u32 RGB values
102            let mut packed_colors = Vec::with_capacity(color_data.len() / 3);
103            for chunk in color_data.chunks_exact(3) {
104                let r = chunk[0] as u32;
105                let g = chunk[1] as u32;
106                let b = chunk[2] as u32;
107                let packed = (r << 16) | (g << 8) | b;
108                packed_colors.push(packed);
109            }
110            
111            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
112                label: Some("TSDF Color Buffer"),
113                contents: bytemuck::cast_slice(&packed_colors),
114                usage: wgpu::BufferUsages::STORAGE,
115            })
116        } else {
117            // Create empty buffer if no color data
118            self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
119                label: Some("TSDF Empty Color Buffer"),
120                contents: bytemuck::cast_slice(&[0u32; 4]), // Small dummy buffer
121                usage: wgpu::BufferUsages::STORAGE,
122            })
123        };
124
125        // Initialize TSDF volume if needed
126        let initial_voxels = vec![TsdfVoxel {
127            tsdf_value: 1.0,
128            weight: 0.0,
129            color_r: 0,
130            color_g: 0,
131            color_b: 0,
132            _padding1: 0,
133            _padding2: 0,
134            _padding3: 0,
135        }; total_voxels];
136
137        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
138            label: Some("TSDF Voxel Buffer"),
139            contents: bytemuck::cast_slice(&initial_voxels),
140            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
141        });
142
143        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
144        let world_to_camera = camera_pose.try_inverse()
145            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
146        
147        let mut camera_transform = [[0.0f32; 4]; 4];
148        for i in 0..4 {
149            for j in 0..4 {
150                camera_transform[i][j] = world_to_camera[(i, j)];
151            }
152        }
153
154        let transform_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
155            label: Some("TSDF Transform Buffer"),
156            contents: bytemuck::cast_slice(&[camera_transform]),
157            usage: wgpu::BufferUsages::UNIFORM,
158        });
159
160        let intrinsics_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
161            label: Some("TSDF Intrinsics Buffer"),
162            contents: bytemuck::bytes_of(intrinsics),
163            usage: wgpu::BufferUsages::UNIFORM,
164        });
165
166        let params = TsdfParams {
167            voxel_size: volume.voxel_size,
168            truncation_distance: volume.truncation_distance,
169            max_weight: 100.0,
170            iso_value: 0.0,
171            resolution: volume.resolution,
172            _padding2: 0,
173            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
174            _padding3: 0.0,
175        };
176
177        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
178            label: Some("TSDF Params Buffer"),
179            contents: bytemuck::bytes_of(&params),
180            usage: wgpu::BufferUsages::UNIFORM,
181        });
182
183        // Create compute pipeline
184        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
185            label: Some("TSDF Integration Shader"),
186            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/tsdf_integration.wgsl").into()),
187        });
188
189        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
190            label: Some("TSDF Integration Pipeline"),
191            layout: None,
192            module: &shader,
193            entry_point: Some("main"),
194            compilation_options: wgpu::PipelineCompilationOptions::default(),
195            cache: None,
196        });
197
198        // Create bind group
199        let bind_group_entries = vec![
200            wgpu::BindGroupEntry {
201                binding: 0,
202                resource: voxel_buffer.as_entire_binding(),
203            },
204            wgpu::BindGroupEntry {
205                binding: 1,
206                resource: depth_buffer.as_entire_binding(),
207            },
208            wgpu::BindGroupEntry {
209                binding: 2,
210                resource: transform_buffer.as_entire_binding(),
211            },
212            wgpu::BindGroupEntry {
213                binding: 3,
214                resource: intrinsics_buffer.as_entire_binding(),
215            },
216            wgpu::BindGroupEntry {
217                binding: 4,
218                resource: params_buffer.as_entire_binding(),
219            },
220            wgpu::BindGroupEntry {
221                binding: 5,
222                resource: color_buffer.as_entire_binding(),
223            },
224        ];
225
226        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
227            label: Some("TSDF Integration Bind Group"),
228            layout: &pipeline.get_bind_group_layout(0),
229            entries: &bind_group_entries,
230        });
231
232        // Dispatch compute shader
233        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
234            label: Some("TSDF Integration Encoder"),
235        });
236
237        {
238            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
239                label: Some("TSDF Integration Pass"),
240                timestamp_writes: None,
241            });
242
243            compute_pass.set_pipeline(&pipeline);
244            compute_pass.set_bind_group(0, &bind_group, &[]);
245            
246            // Dispatch with 4x4x4 workgroups
247            let workgroup_size = 4;
248            let dispatch_x = (volume.resolution[0] + workgroup_size - 1) / workgroup_size;
249            let dispatch_y = (volume.resolution[1] + workgroup_size - 1) / workgroup_size;
250            let dispatch_z = (volume.resolution[2] + workgroup_size - 1) / workgroup_size;
251            
252            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
253        }
254
255        // Read back results
256        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
257            label: Some("TSDF Staging Buffer"),
258            size: (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64,
259            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
260            mapped_at_creation: false,
261        });
262
263        encoder.copy_buffer_to_buffer(
264            &voxel_buffer,
265            0,
266            &staging_buffer,
267            0,
268            staging_buffer.size(),
269        );
270
271        self.queue.submit(std::iter::once(encoder.finish()));
272
273        let buffer_slice = staging_buffer.slice(..);
274        let (sender, receiver) = flume::unbounded();
275        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
276            sender.send(result).unwrap();
277        });
278
279        self.device.poll(wgpu::PollType::Wait {
280            submission_index: None,
281            timeout: None,
282        });
283        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))?
284            .map_err(|e| Error::Gpu(format!("Buffer mapping failed: {:?}", e)))?;
285
286        let data = buffer_slice.get_mapped_range();
287        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
288        
289        drop(data);
290        staging_buffer.unmap();
291
292        Ok(result)
293    }
294
295    /// Extract point cloud from TSDF volume using marching cubes
296    pub async fn tsdf_extract_surface(
297        &self,
298        volume: &TsdfVolume,
299        voxels: &[TsdfVoxel],
300        iso_value: f32,
301    ) -> Result<PointCloud<ColoredPoint3f>> {
302        let total_voxels = (volume.resolution[0] * volume.resolution[1] * volume.resolution[2]) as usize;
303        let max_points = std::cmp::min(total_voxels, 1_000_000); // Limit to reasonable size
304        
305        // Create voxel buffer
306        let voxel_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
307            label: Some("TSDF Voxel Buffer"),
308            contents: bytemuck::cast_slice(voxels),
309            usage: wgpu::BufferUsages::STORAGE,
310        });
311
312        // Create output buffers
313        let points_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
314            label: Some("Surface Points Buffer"),
315            size: (max_points * std::mem::size_of::<GpuPoint3f>()) as u64,
316            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
317            mapped_at_creation: false,
318        });
319
320        let point_count_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
321            label: Some("Point Count Buffer"),
322            contents: bytemuck::bytes_of(&0u32),
323            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
324        });
325
326        let params = TsdfParams {
327            voxel_size: volume.voxel_size,
328            truncation_distance: volume.truncation_distance,
329            max_weight: 100.0,
330            iso_value,
331            resolution: volume.resolution,
332            _padding2: 0,
333            origin: [volume.origin.x, volume.origin.y, volume.origin.z],
334            _padding3: 0.0,
335        };
336
337        let params_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
338            label: Some("Surface Extraction Params Buffer"),
339            contents: bytemuck::bytes_of(&params),
340            usage: wgpu::BufferUsages::UNIFORM,
341        });
342
343        // Create compute pipeline
344        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
345            label: Some("Surface Extraction Shader"),
346            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/surface_extraction.wgsl").into()),
347        });
348
349        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
350            label: Some("Surface Extraction Pipeline"),
351            layout: None,
352            module: &shader,
353            entry_point: Some("main"),
354            compilation_options: wgpu::PipelineCompilationOptions::default(),
355            cache: None,
356        });
357
358        // Create bind group
359        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
360            label: Some("Surface Extraction Bind Group"),
361            layout: &pipeline.get_bind_group_layout(0),
362            entries: &[
363                wgpu::BindGroupEntry {
364                    binding: 0,
365                    resource: voxel_buffer.as_entire_binding(),
366                },
367                wgpu::BindGroupEntry {
368                    binding: 1,
369                    resource: points_buffer.as_entire_binding(),
370                },
371                wgpu::BindGroupEntry {
372                    binding: 2,
373                    resource: params_buffer.as_entire_binding(),
374                },
375                wgpu::BindGroupEntry {
376                    binding: 3,
377                    resource: point_count_buffer.as_entire_binding(),
378                },
379            ],
380        });
381
382        // Create staging buffer for reading back results
383        let point_count_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
384            label: Some("Point Count Staging Buffer"),
385            size: std::mem::size_of::<u32>() as u64,
386            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
387            mapped_at_creation: false,
388        });
389
390        // Dispatch compute shader
391        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
392            label: Some("Surface Extraction Encoder"),
393        });
394
395        {
396            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
397                label: Some("Surface Extraction Pass"),
398                timestamp_writes: None,
399            });
400
401            compute_pass.set_pipeline(&pipeline);
402            compute_pass.set_bind_group(0, &bind_group, &[]);
403            compute_pass.dispatch_workgroups(
404                (volume.resolution[0] + 3) / 4,
405                (volume.resolution[1] + 3) / 4,
406                (volume.resolution[2] + 3) / 4,
407            );
408        }
409
410        // Copy point count to staging buffer
411        encoder.copy_buffer_to_buffer(
412            &point_count_buffer,
413            0,
414            &point_count_staging_buffer,
415            0,
416            std::mem::size_of::<u32>() as u64,
417        );
418
419        self.queue.submit(Some(encoder.finish()));
420
421        // Read point count
422        let point_count_slice = point_count_staging_buffer.slice(..);
423        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
424        point_count_slice.map_async(wgpu::MapMode::Read, move |result| {
425            tx.send(result).unwrap();
426        });
427        self.device.poll(wgpu::PollType::Wait {
428            submission_index: None,
429            timeout: None,
430        });
431        rx.receive().await.unwrap()?;
432
433        let mapped_range = point_count_slice.get_mapped_range();
434        let point_count = bytemuck::cast_slice::<u8, u32>(mapped_range.as_ref())[0] as usize;
435        drop(mapped_range);
436        point_count_staging_buffer.unmap();
437
438        if point_count == 0 {
439            return Ok(PointCloud { points: Vec::new() });
440        }
441
442        // Create staging buffer for points
443        let points_staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
444            label: Some("Points Staging Buffer"),
445            size: (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
446            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
447            mapped_at_creation: false,
448        });
449
450        // Copy points to staging buffer
451        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
452            label: Some("Points Copy Encoder"),
453        });
454
455        encoder.copy_buffer_to_buffer(
456            &points_buffer,
457            0,
458            &points_staging_buffer,
459            0,
460            (point_count * std::mem::size_of::<GpuPoint3f>()) as u64,
461        );
462
463        self.queue.submit(Some(encoder.finish()));
464
465        // Read points
466        let points_slice = points_staging_buffer.slice(..);
467        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
468        points_slice.map_async(wgpu::MapMode::Read, move |result| {
469            tx.send(result).unwrap();
470        });
471        self.device.poll(wgpu::PollType::Wait {
472            submission_index: None,
473            timeout: None,
474        });
475        rx.receive().await.unwrap()?;
476
477        let mapped_range = points_slice.get_mapped_range();
478        let gpu_points = bytemuck::cast_slice::<u8, GpuPoint3f>(mapped_range.as_ref());
479        let mut points = Vec::with_capacity(point_count);
480
481        for gpu_point in gpu_points.iter().take(point_count) {
482            points.push(ColoredPoint3f {
483                position: Point3::new(gpu_point.x, gpu_point.y, gpu_point.z),
484                color: [gpu_point.r as u8, gpu_point.g as u8, gpu_point.b as u8],
485            });
486        }
487
488        drop(mapped_range);  // Explicitly drop the mapped range before unmapping
489        points_staging_buffer.unmap();
490
491        Ok(PointCloud { points })
492    }
493}
494
495impl TsdfVolumeGpu {
496    /// Creates a new TSDF volume on the GPU.
497    pub fn new(gpu: &GpuContext, volume_params: TsdfVolume) -> Self {
498        let total_voxels = (volume_params.resolution[0] * volume_params.resolution[1] * volume_params.resolution[2]) as usize;
499        
500        // Initialize voxels with default values
501        let initial_voxels = vec![TsdfVoxel {
502            tsdf_value: 1.0,
503            weight: 0.0,
504            color_r: 0,
505            color_g: 0,
506            color_b: 0,
507            _padding1: 0,
508            _padding2: 0,
509            _padding3: 0,
510        }; total_voxels];
511
512        let voxel_buffer = gpu.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
513            label: Some("TSDF Voxel Buffer"),
514            contents: bytemuck::cast_slice(&initial_voxels),
515            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
516        });
517
518        Self {
519            volume: volume_params,
520            voxel_buffer,
521        }
522    }
523
524    /// Integrates a depth image into the TSDF volume.
525    pub async fn integrate(
526        &self,
527        gpu: &GpuContext,
528        depth_image: &[f32],
529        color_image: Option<&[u8]>, // RGB color data
530        camera_pose: &Matrix4<f32>,
531        intrinsics: &CameraIntrinsics,
532    ) -> Result<()> {
533        // Create buffers for depth, color, transform, and parameters
534        let depth_buffer = gpu.create_buffer_init("TSDF Depth Buffer", depth_image, wgpu::BufferUsages::STORAGE);
535
536        let color_buffer = if let Some(data) = color_image {
537            gpu.create_buffer_init("TSDF Color Buffer", data, wgpu::BufferUsages::STORAGE)
538        } else {
539            // Create a dummy buffer if no color image is provided
540            gpu.create_buffer_init("TSDF Dummy Color Buffer", &[0u32; 4], wgpu::BufferUsages::STORAGE)
541        };
542
543        // Convert camera transform to world-to-camera matrix (inverse of camera pose)
544        let world_to_camera = camera_pose.try_inverse()
545            .ok_or_else(|| Error::Gpu("Failed to invert camera pose matrix".into()))?;
546        
547        let mut camera_transform = [[0.0f32; 4]; 4];
548        for i in 0..4 {
549            for j in 0..4 {
550                camera_transform[i][j] = world_to_camera[(i, j)];
551            }
552        }
553
554        let transform_buffer = gpu.create_buffer_init(
555            "TSDF Transform Buffer",
556            &[camera_transform],
557            wgpu::BufferUsages::UNIFORM,
558        );
559
560        let intrinsics_buffer = gpu.create_buffer_init(
561            "TSDF Intrinsics Buffer",
562            &[*intrinsics],
563            wgpu::BufferUsages::UNIFORM,
564        );
565
566        let params = TsdfParams {
567            voxel_size: self.volume.voxel_size,
568            truncation_distance: self.volume.truncation_distance,
569            max_weight: 100.0,
570            iso_value: 0.0,
571            resolution: self.volume.resolution,
572            _padding2: 0,
573            origin: [self.volume.origin.x, self.volume.origin.y, self.volume.origin.z],
574            _padding3: 0.0,
575        };
576        let params_buffer = gpu.create_buffer_init("TSDF Params Buffer", &[params], wgpu::BufferUsages::UNIFORM);
577
578        // Create compute pipeline
579        let shader = gpu.create_shader_module("TSDF Integration Shader", include_str!("shaders/tsdf_integration.wgsl"));
580        let pipeline = gpu.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
581            label: Some("TSDF Integration Pipeline"),
582            layout: None,
583            module: &shader,
584            entry_point: Some("main"),
585            compilation_options: wgpu::PipelineCompilationOptions::default(),
586            cache: None,
587        });
588
589        // Create bind group
590        let bind_group_entries = vec![
591            wgpu::BindGroupEntry {
592                binding: 0,
593                resource: self.voxel_buffer.as_entire_binding(),
594            },
595            wgpu::BindGroupEntry {
596                binding: 1,
597                resource: depth_buffer.as_entire_binding(),
598            },
599            wgpu::BindGroupEntry {
600                binding: 2,
601                resource: transform_buffer.as_entire_binding(),
602            },
603            wgpu::BindGroupEntry {
604                binding: 3,
605                resource: intrinsics_buffer.as_entire_binding(),
606            },
607            wgpu::BindGroupEntry {
608                binding: 4,
609                resource: params_buffer.as_entire_binding(),
610            },
611            wgpu::BindGroupEntry {
612                binding: 5,
613                resource: color_buffer.as_entire_binding(),
614            },
615        ];
616
617        let bind_group = gpu.create_bind_group("TSDF Integration Bind Group", &pipeline.get_bind_group_layout(0), &bind_group_entries);
618
619        // Dispatch compute shader
620        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
621            label: Some("TSDF Integration Encoder"),
622        });
623
624        {
625            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
626                label: Some("TSDF Integration Pass"),
627                timestamp_writes: None,
628            });
629
630            compute_pass.set_pipeline(&pipeline);
631            compute_pass.set_bind_group(0, &bind_group, &[]);
632            
633            // Dispatch with 4x4x4 workgroups
634            let workgroup_size = 4;
635            let dispatch_x = (self.volume.resolution[0] + workgroup_size - 1) / workgroup_size;
636            let dispatch_y = (self.volume.resolution[1] + workgroup_size - 1) / workgroup_size;
637            let dispatch_z = (self.volume.resolution[2] + workgroup_size - 1) / workgroup_size;
638            
639            println!("Dispatching compute shader with {} x {} x {} workgroups", dispatch_x, dispatch_y, dispatch_z);
640            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, dispatch_z);
641        }
642
643        gpu.queue.submit(std::iter::once(encoder.finish()));
644        Ok(())
645    }
646
647    /// Downloads the TSDF voxel data from the GPU.
648    pub async fn download_voxels(&self, gpu: &GpuContext) -> Result<Vec<TsdfVoxel>> {
649        let total_voxels = (self.volume.resolution[0] * self.volume.resolution[1] * self.volume.resolution[2]) as usize;
650        let buffer_size = (total_voxels * std::mem::size_of::<TsdfVoxel>()) as u64;
651
652        let staging_buffer = gpu.device.create_buffer(&wgpu::BufferDescriptor {
653            label: Some("TSDF Staging Buffer"),
654            size: buffer_size,
655            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
656            mapped_at_creation: false,
657        });
658
659        let mut encoder = gpu.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
660            label: Some("TSDF Download Encoder"),
661        });
662
663        encoder.copy_buffer_to_buffer(
664            &self.voxel_buffer,
665            0,
666            &staging_buffer,
667            0,
668            buffer_size,
669        );
670
671        gpu.queue.submit(std::iter::once(encoder.finish()));
672
673        let buffer_slice = staging_buffer.slice(..);
674        let (sender, receiver) = flume::unbounded();
675        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
676            sender.send(result).unwrap();
677        });
678
679        gpu.device.poll(wgpu::PollType::Wait {
680            submission_index: None,
681            timeout: None,
682        });
683        receiver.recv_async().await.map_err(|_| Error::Gpu("Failed to receive mapping result".into()))??;
684
685        let data = buffer_slice.get_mapped_range();
686        let result: Vec<TsdfVoxel> = bytemuck::cast_slice(&data).to_vec();
687        
688        drop(data);
689        staging_buffer.unmap();
690
691        Ok(result)
692    }
693
694    /// Extract point cloud from TSDF volume using marching cubes
695    pub async fn extract_surface(&self, gpu: &GpuContext, iso_value: f32) -> Result<PointCloud<ColoredPoint3f>> {
696        let voxels = self.download_voxels(gpu).await?;
697        gpu.tsdf_extract_surface(&self.volume, &voxels, iso_value).await
698    }
699}
700
701/// Create a new TSDF volume with specified parameters
702pub fn create_tsdf_volume(
703    voxel_size: f32,
704    truncation_distance: f32,
705    resolution: [u32; 3],
706    origin: Point3<f32>,
707) -> TsdfVolume {
708    TsdfVolume {
709        voxel_size,
710        truncation_distance,
711        resolution,
712        origin,
713    }
714}
715
716/// GPU-accelerated TSDF integration from depth image
717pub async fn gpu_tsdf_integrate(
718    gpu_context: &GpuContext,
719    volume: &mut TsdfVolume,
720    depth_image: &[f32],
721    color_image: Option<&[u8]>,
722    camera_pose: &Matrix4<f32>,
723    intrinsics: &CameraIntrinsics,
724) -> Result<Vec<TsdfVoxel>> {
725    gpu_context.tsdf_integrate(volume, depth_image, color_image, camera_pose, intrinsics).await
726}
727
728/// GPU-accelerated surface extraction from TSDF volume
729pub async fn gpu_tsdf_extract_surface(
730    gpu_context: &GpuContext,
731    volume: &TsdfVolume,
732    voxels: &[TsdfVoxel],
733    iso_value: f32,
734) -> Result<PointCloud<ColoredPoint3f>> {
735    gpu_context.tsdf_extract_surface(volume, voxels, iso_value).await
736}
737
738#[cfg(test)]
739mod tests {
740    use super::*;
741    use crate::device::GpuContext;
742    use nalgebra::{Matrix4, Point3};
743    use approx::assert_relative_eq;
744
745    /// Try to create a GPU context, return None if not available
746    async fn try_create_gpu_context() -> Option<GpuContext> {
747        match GpuContext::new().await {
748            Ok(gpu) => Some(gpu),
749            Err(_) => {
750                println!("⚠️  GPU not available, skipping GPU-dependent test");
751                None
752            }
753        }
754    }
755
756    /// Create simple depth image for basic testing
757    fn create_simple_depth_image(width: u32, height: u32, depth: f32) -> Vec<f32> {
758        vec![depth; (width * height) as usize]
759    }
760
761    /// Create a basic camera setup for testing
762    fn create_test_camera() -> CameraIntrinsics {
763        CameraIntrinsics {
764            fx: 525.0,
765            fy: 525.0,
766            cx: 319.5,
767            cy: 239.5,
768            width: 640,
769            height: 480,
770            depth_scale: 1.0,
771            _padding: 0.0,
772        }
773    }
774
775    /// Create identity camera pose
776    fn create_identity_pose() -> Matrix4<f32> {
777        Matrix4::new(
778            1.0, 0.0, 0.0, 0.0,
779            0.0, 1.0, 0.0, 0.0,
780            0.0, 0.0, 1.0, 0.0,
781            0.0, 0.0, 0.0, 1.0,
782        )
783    }
784
785    #[test]
786    fn test_tsdf_basic_integration() {
787        pollster::block_on(async {
788            let Some(gpu) = try_create_gpu_context().await else {
789                return;
790            };
791
792            // Create a simple TSDF volume
793            let voxel_size = 0.02; // 2cm voxels for faster processing
794            let truncation_distance = 0.1; 
795            let resolution = [32, 32, 32]; // Smaller resolution for speed
796            let origin = Point3::new(-0.32, -0.32, 0.0);
797
798            let volume_params = create_tsdf_volume(
799                voxel_size,
800                truncation_distance,
801                resolution,
802                origin,
803            );
804            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
805
806            // Create simple depth image with constant depth
807            let intrinsics = create_test_camera();
808            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.5);
809            let camera_pose = create_identity_pose();
810
811            // Test integration
812            let result = tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics).await;
813            assert!(result.is_ok(), "TSDF integration should succeed");
814
815            // Test voxel download
816            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
817            assert_eq!(voxels.len(), (32 * 32 * 32) as usize, "Should have correct number of voxels");
818
819            // Check that some voxels have been updated
820            let updated_voxels = voxels.iter().filter(|v| v.weight > 0.0).count();
821            assert!(updated_voxels > 0, "Some voxels should have been updated");
822
823            println!("✓ Basic integration test passed: {} voxels updated", updated_voxels);
824        });
825    }
826
827    #[test]
828    fn test_tsdf_surface_extraction() {
829        pollster::block_on(async {
830            let Some(gpu) = try_create_gpu_context().await else {
831                return;
832            };
833
834            // Create TSDF volume 
835            let voxel_size = 0.02;
836            let truncation_distance = 0.1;
837            let resolution = [32, 32, 32];
838            let origin = Point3::new(-0.32, -0.32, 0.0);
839
840            let volume_params = create_tsdf_volume(
841                voxel_size,
842                truncation_distance,
843                resolution,
844                origin,
845            );
846            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
847
848            // Integrate a simple depth image
849            let intrinsics = create_test_camera();
850            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
851            let camera_pose = create_identity_pose();
852
853            tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
854                .await
855                .unwrap();
856
857            // Extract surface
858            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
859            
860            // Should extract some points
861            assert!(!point_cloud.points.is_empty(), "Should extract surface points");
862            
863            // Points should be in reasonable Z range around the depth value
864            let avg_z = point_cloud.points.iter()
865                .map(|p| p.position.z)
866                .sum::<f32>() / point_cloud.points.len() as f32;
867            
868            assert!(avg_z > 0.2 && avg_z < 0.4, "Average Z should be near depth value of 0.3");
869            
870            println!("✓ Surface extraction test passed: {} points extracted, avg Z: {:.3}", 
871                     point_cloud.points.len(), avg_z);
872        });
873    }
874
875    #[test]
876    fn test_tsdf_multiple_integrations() {
877        pollster::block_on(async {
878            let Some(gpu) = try_create_gpu_context().await else {
879                return;
880            };
881
882            // Create TSDF volume
883            let voxel_size = 0.02;
884            let truncation_distance = 0.1;
885            let resolution = [32, 32, 32];
886            let origin = Point3::new(-0.32, -0.32, 0.0);
887
888            let volume_params = create_tsdf_volume(
889                voxel_size,
890                truncation_distance,
891                resolution,
892                origin,
893            );
894            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
895
896            let intrinsics = create_test_camera();
897            let camera_pose = create_identity_pose();
898
899            // Integrate multiple depth images
900            let depths = [0.25, 0.3, 0.35];
901            for &depth in &depths {
902                let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, depth);
903                tsdf_volume_gpu.integrate(&gpu, &depth_image, None, &camera_pose, &intrinsics)
904                    .await
905                    .unwrap();
906            }
907
908            // Check voxel weights have increased
909            let voxels = tsdf_volume_gpu.download_voxels(&gpu).await.unwrap();
910            let max_weight = voxels.iter().map(|v| v.weight).fold(0.0, f32::max);
911            assert!(max_weight > 1.0, "Multiple integrations should increase voxel weights");
912
913            // Extract surface
914            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
915            assert!(!point_cloud.points.is_empty(), "Should extract surface after multiple integrations");
916
917            println!("✓ Multiple integration test passed: max weight {:.1}, {} points extracted", 
918                     max_weight, point_cloud.points.len());
919        });
920    }
921
922    #[test]
923    fn test_tsdf_coordinate_system() {
924        pollster::block_on(async {
925            let Some(_gpu) = try_create_gpu_context().await else {
926                return;
927            };
928
929            // Test basic coordinate system consistency
930            let voxel_size = 0.02;
931            let resolution = [32, 32, 32];
932            let origin = Point3::new(-0.32, -0.32, 0.0);
933
934            // Check volume bounds
935            let max_coord = Point3::new(
936                origin.x + (resolution[0] as f32) * voxel_size,
937                origin.y + (resolution[1] as f32) * voxel_size,
938                origin.z + (resolution[2] as f32) * voxel_size,
939            );
940
941            assert_relative_eq!(max_coord.x, 0.32, epsilon = 0.01);
942            assert_relative_eq!(max_coord.y, 0.32, epsilon = 0.01);
943            assert_relative_eq!(max_coord.z, 0.64, epsilon = 0.01);
944
945            // Test camera transforms
946            let camera_pose = create_identity_pose();
947            let world_to_camera = camera_pose.try_inverse().unwrap();
948            
949            let test_point = Point3::new(0.1, 0.2, 0.3);
950            let camera_point = world_to_camera.transform_point(&test_point);
951            
952            // For identity transform, should be the same
953            assert_relative_eq!(test_point.x, camera_point.x, epsilon = 0.001);
954            assert_relative_eq!(test_point.y, camera_point.y, epsilon = 0.001);
955            assert_relative_eq!(test_point.z, camera_point.z, epsilon = 0.001);
956
957            println!("✓ Coordinate system test passed");
958        });
959    }
960
961    #[test]
962    fn test_tsdf_color_integration() {
963        pollster::block_on(async {
964            let Some(gpu) = try_create_gpu_context().await else {
965                return;
966            };
967
968            // Create TSDF volume
969            let voxel_size = 0.02;
970            let truncation_distance = 0.1;
971            let resolution = [32, 32, 32];
972            let origin = Point3::new(-0.32, -0.32, 0.0);
973
974            let volume_params = create_tsdf_volume(
975                voxel_size,
976                truncation_distance,
977                resolution,
978                origin,
979            );
980            let tsdf_volume_gpu = TsdfVolumeGpu::new(&gpu, volume_params);
981
982            // Create depth and color images
983            let intrinsics = create_test_camera();
984            let depth_image = create_simple_depth_image(intrinsics.width, intrinsics.height, 0.3);
985            
986            // Simple red color image
987            let pixel_count = (intrinsics.width * intrinsics.height) as usize;
988            let mut color_image = Vec::with_capacity(pixel_count * 3);
989            for _ in 0..pixel_count {
990                color_image.extend_from_slice(&[255u8, 0u8, 0u8]); // RGB: red
991            }
992            
993            let camera_pose = create_identity_pose();
994
995            // Integrate with color
996            tsdf_volume_gpu.integrate(&gpu, &depth_image, Some(&color_image), &camera_pose, &intrinsics)
997                .await
998                .unwrap();
999
1000            // Extract surface
1001            let point_cloud = tsdf_volume_gpu.extract_surface(&gpu, 0.0).await.unwrap();
1002            
1003            assert!(!point_cloud.points.is_empty(), "Should extract colored surface points");
1004            
1005            // Check that some points have red color
1006            let red_points = point_cloud.points.iter()
1007                .filter(|p| p.color[0] > 200)
1008                .count();
1009            
1010            assert!(red_points > 0, "Some points should have red color");
1011
1012            println!("✓ Color integration test passed: {} red points", red_points);
1013        });
1014    }
1015}