// File: sift/gpu_sift_v2.rs

// Full GPU SIFT implementation using a texture-based compute pipeline.
// Performance target: under 20 ms, comparable to VulkanSift.
3
4use crate::keypoints::KeyPoint;
5use std::sync::Arc;
6use wgpu;
7
/// Tunable parameters for the GPU SIFT (V2) detector.
///
/// Use [`Default::default`] for the stock settings and override individual
/// fields with struct-update syntax.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuSiftConfigV2 {
    /// Upper bound on the number of pyramid octaves. Fewer may be allocated:
    /// resource allocation stops once an octave would be smaller than 8 px
    /// in either dimension.
    pub octaves: u32,
    /// Scales per octave (`S` in the SIFT paper, typically 3). Each octave
    /// holds `S + 3` Gaussian images.
    pub scales_per_octave: u32,
    /// Sigma of the base Gaussian blur.
    // NOTE(review): consumed outside the visible part of this file — confirm exact usage.
    pub base_sigma: f32,
    /// Contrast threshold for keypoint rejection.
    // NOTE(review): presumably applied by the extrema shader — confirm.
    pub contrast_threshold: f32,
    /// Edge-response threshold for keypoint rejection.
    // NOTE(review): presumably applied by the extrema shader — confirm.
    pub edge_threshold: f32,
    /// Maximum number of keypoints to detect; sizes the GPU keypoint buffers.
    pub max_keypoints: u32,
}
16
17impl Default for GpuSiftConfigV2 {
18    fn default() -> Self {
19        Self {
20            octaves: 4,
21            scales_per_octave: 3, // Results in S+3=6 Gaussian images per octave
22            base_sigma: 1.6,
23            contrast_threshold: 0.04,
24            edge_threshold: 10.0,
25            max_keypoints: 4096, // Reasonable default for most images
26        }
27    }
28}
29
30/// Full GPU SIFT using texture-based pipeline
31pub struct GpuSiftV2 {
32    device: Arc<wgpu::Device>,
33    queue: Arc<wgpu::Queue>,
34
35    // Pipelines
36    blur_h_pipeline: wgpu::ComputePipeline,
37    blur_v_pipeline: wgpu::ComputePipeline,
38    dog_pipeline: wgpu::ComputePipeline,
39    downsample_pipeline: wgpu::ComputePipeline,
40    extrema_pipeline: wgpu::ComputePipeline,
41    orientation_pipeline: wgpu::ComputePipeline,
42    descriptor_pipeline: wgpu::ComputePipeline,
43    prepare_orient_indirect_pipeline: wgpu::ComputePipeline,
44    prepare_desc_indirect_pipeline: wgpu::ComputePipeline,
45
46    // Bind group layouts
47    blur_bind_group_layout: wgpu::BindGroupLayout,
48    dog_bind_group_layout: wgpu::BindGroupLayout,
49    downsample_bind_group_layout: wgpu::BindGroupLayout,
50    extrema_bind_group_layout: wgpu::BindGroupLayout,
51    prepare_indirect_bind_group_layout: wgpu::BindGroupLayout,
52
53    // Sampler
54    linear_sampler: wgpu::Sampler,
55    nearest_sampler: wgpu::Sampler,
56
57    // Preallocated resources (resized per image)
58    resources: Option<GpuResources>,
59
60    config: GpuSiftConfigV2,
61}
62
63struct GpuResources {
64    width: u32,
65    height: u32,
66
67    // Gaussian pyramid textures: [octave][scale]
68    // Each octave has scales_per_octave + 3 images
69    gaussian_textures: Vec<Vec<wgpu::Texture>>,
70    gaussian_views: Vec<Vec<wgpu::TextureView>>,
71
72    // DoG pyramid: [octave][scale] - one less than Gaussian per octave
73    dog_textures: Vec<Vec<wgpu::Texture>>,
74    dog_views: Vec<Vec<wgpu::TextureView>>,
75
76    // Temporary textures for blur passes
77    temp_textures: Vec<wgpu::Texture>,
78    temp_views: Vec<wgpu::TextureView>,
79
80    // Keypoint buffers
81    keypoint_counter: wgpu::Buffer,
82    keypoints: wgpu::Buffer,
83    oriented_keypoint_counter: wgpu::Buffer,
84    oriented_keypoints: wgpu::Buffer,
85    descriptors: wgpu::Buffer,
86
87    // Indirect dispatch buffers (12 bytes each: x, y, z workgroups)
88    orientation_indirect: wgpu::Buffer,
89    descriptor_indirect: wgpu::Buffer,
90
91    // Readback
92    readback_counters: wgpu::Buffer,
93    readback_keypoints: wgpu::Buffer,
94    readback_descriptors: wgpu::Buffer,
95}
96
97impl GpuSiftV2 {
98    pub async fn new(config: GpuSiftConfigV2) -> Result<Self, Box<dyn std::error::Error>> {
99        #[cfg(not(target_arch = "wasm32"))]
100        let instance = wgpu::Instance::default();
101        #[cfg(target_arch = "wasm32")]
102        let instance = wgpu::Instance::default(); // Browser WebGPU default is usually fine (backends: BROWSER_WEBGPU)
103
104        let adapter = instance
105            .request_adapter(&wgpu::RequestAdapterOptions {
106                power_preference: wgpu::PowerPreference::HighPerformance,
107                compatible_surface: None,
108                force_fallback_adapter: false,
109            })
110            .await
111            .map_err(|e| format!("Failed to find an appropriate adapter: {:?}", e))?;
112
113        let (device, queue) = adapter
114            .request_device(&wgpu::DeviceDescriptor {
115                label: Some("SIFT V2 Device"),
116                required_features: wgpu::Features::empty(),
117                required_limits: wgpu::Limits::downlevel_defaults(),
118                ..Default::default()
119            })
120            .await
121            .map_err(|e| format!("Failed to create device: {:?}", e))?;
122
123        let device = Arc::new(device);
124        let queue = Arc::new(queue);
125
126        // Create samplers
127        // Note: R32Float doesn't support filtering on most hardware, so use Nearest
128        let linear_sampler = device.create_sampler(&wgpu::SamplerDescriptor {
129            label: Some("Linear Sampler"),
130            address_mode_u: wgpu::AddressMode::ClampToEdge,
131            address_mode_v: wgpu::AddressMode::ClampToEdge,
132            mag_filter: wgpu::FilterMode::Nearest, // R32Float doesn't support linear
133            min_filter: wgpu::FilterMode::Nearest,
134            ..Default::default()
135        });
136
137        let nearest_sampler = device.create_sampler(&wgpu::SamplerDescriptor {
138            label: Some("Nearest Sampler"),
139            address_mode_u: wgpu::AddressMode::ClampToEdge,
140            address_mode_v: wgpu::AddressMode::ClampToEdge,
141            mag_filter: wgpu::FilterMode::Nearest,
142            min_filter: wgpu::FilterMode::Nearest,
143            ..Default::default()
144        });
145
146        // Create bind group layouts
147        let blur_bind_group_layout =
148            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
149                label: Some("Blur BGL"),
150                entries: &[
151                    // Input texture - R32Float is not filterable by default on some hardware
152                    // Use unfilterable sample type and separate sampler
153                    wgpu::BindGroupLayoutEntry {
154                        binding: 0,
155                        visibility: wgpu::ShaderStages::COMPUTE,
156                        ty: wgpu::BindingType::Texture {
157                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
158                            view_dimension: wgpu::TextureViewDimension::D2,
159                            multisampled: false,
160                        },
161                        count: None,
162                    },
163                    // Sampler - use NonFiltering since R32Float doesn't support filtering
164                    wgpu::BindGroupLayoutEntry {
165                        binding: 1,
166                        visibility: wgpu::ShaderStages::COMPUTE,
167                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::NonFiltering),
168                        count: None,
169                    },
170                    // Output storage texture
171                    wgpu::BindGroupLayoutEntry {
172                        binding: 2,
173                        visibility: wgpu::ShaderStages::COMPUTE,
174                        ty: wgpu::BindingType::StorageTexture {
175                            access: wgpu::StorageTextureAccess::WriteOnly,
176                            format: wgpu::TextureFormat::R32Float,
177                            view_dimension: wgpu::TextureViewDimension::D2,
178                        },
179                        count: None,
180                    },
181                    // Params uniform
182                    wgpu::BindGroupLayoutEntry {
183                        binding: 3,
184                        visibility: wgpu::ShaderStages::COMPUTE,
185                        ty: wgpu::BindingType::Buffer {
186                            ty: wgpu::BufferBindingType::Uniform,
187                            has_dynamic_offset: false,
188                            min_binding_size: None,
189                        },
190                        count: None,
191                    },
192                    // Kernel weights
193                    wgpu::BindGroupLayoutEntry {
194                        binding: 4,
195                        visibility: wgpu::ShaderStages::COMPUTE,
196                        ty: wgpu::BindingType::Buffer {
197                            ty: wgpu::BufferBindingType::Storage { read_only: true },
198                            has_dynamic_offset: false,
199                            min_binding_size: None,
200                        },
201                        count: None,
202                    },
203                ],
204            });
205
206        let dog_bind_group_layout =
207            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
208                label: Some("DoG BGL"),
209                entries: &[
210                    // Input texture 1 (higher blur)
211                    wgpu::BindGroupLayoutEntry {
212                        binding: 0,
213                        visibility: wgpu::ShaderStages::COMPUTE,
214                        ty: wgpu::BindingType::Texture {
215                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
216                            view_dimension: wgpu::TextureViewDimension::D2,
217                            multisampled: false,
218                        },
219                        count: None,
220                    },
221                    // Input texture 2 (lower blur)
222                    wgpu::BindGroupLayoutEntry {
223                        binding: 1,
224                        visibility: wgpu::ShaderStages::COMPUTE,
225                        ty: wgpu::BindingType::Texture {
226                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
227                            view_dimension: wgpu::TextureViewDimension::D2,
228                            multisampled: false,
229                        },
230                        count: None,
231                    },
232                    // Output storage texture
233                    wgpu::BindGroupLayoutEntry {
234                        binding: 2,
235                        visibility: wgpu::ShaderStages::COMPUTE,
236                        ty: wgpu::BindingType::StorageTexture {
237                            access: wgpu::StorageTextureAccess::WriteOnly,
238                            format: wgpu::TextureFormat::R32Float,
239                            view_dimension: wgpu::TextureViewDimension::D2,
240                        },
241                        count: None,
242                    },
243                    // Params
244                    wgpu::BindGroupLayoutEntry {
245                        binding: 3,
246                        visibility: wgpu::ShaderStages::COMPUTE,
247                        ty: wgpu::BindingType::Buffer {
248                            ty: wgpu::BufferBindingType::Uniform,
249                            has_dynamic_offset: false,
250                            min_binding_size: None,
251                        },
252                        count: None,
253                    },
254                ],
255            });
256
257        let downsample_bind_group_layout =
258            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
259                label: Some("Downsample BGL"),
260                entries: &[
261                    wgpu::BindGroupLayoutEntry {
262                        binding: 0,
263                        visibility: wgpu::ShaderStages::COMPUTE,
264                        ty: wgpu::BindingType::Texture {
265                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
266                            view_dimension: wgpu::TextureViewDimension::D2,
267                            multisampled: false,
268                        },
269                        count: None,
270                    },
271                    wgpu::BindGroupLayoutEntry {
272                        binding: 1,
273                        visibility: wgpu::ShaderStages::COMPUTE,
274                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::NonFiltering),
275                        count: None,
276                    },
277                    wgpu::BindGroupLayoutEntry {
278                        binding: 2,
279                        visibility: wgpu::ShaderStages::COMPUTE,
280                        ty: wgpu::BindingType::StorageTexture {
281                            access: wgpu::StorageTextureAccess::WriteOnly,
282                            format: wgpu::TextureFormat::R32Float,
283                            view_dimension: wgpu::TextureViewDimension::D2,
284                        },
285                        count: None,
286                    },
287                    wgpu::BindGroupLayoutEntry {
288                        binding: 3,
289                        visibility: wgpu::ShaderStages::COMPUTE,
290                        ty: wgpu::BindingType::Buffer {
291                            ty: wgpu::BufferBindingType::Uniform,
292                            has_dynamic_offset: false,
293                            min_binding_size: None,
294                        },
295                        count: None,
296                    },
297                ],
298            });
299
300        let extrema_bind_group_layout =
301            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
302                label: Some("Extrema BGL"),
303                entries: &[
304                    // DoG below
305                    wgpu::BindGroupLayoutEntry {
306                        binding: 0,
307                        visibility: wgpu::ShaderStages::COMPUTE,
308                        ty: wgpu::BindingType::Texture {
309                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
310                            view_dimension: wgpu::TextureViewDimension::D2,
311                            multisampled: false,
312                        },
313                        count: None,
314                    },
315                    // DoG current
316                    wgpu::BindGroupLayoutEntry {
317                        binding: 1,
318                        visibility: wgpu::ShaderStages::COMPUTE,
319                        ty: wgpu::BindingType::Texture {
320                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
321                            view_dimension: wgpu::TextureViewDimension::D2,
322                            multisampled: false,
323                        },
324                        count: None,
325                    },
326                    // DoG above
327                    wgpu::BindGroupLayoutEntry {
328                        binding: 2,
329                        visibility: wgpu::ShaderStages::COMPUTE,
330                        ty: wgpu::BindingType::Texture {
331                            sample_type: wgpu::TextureSampleType::Float { filterable: false },
332                            view_dimension: wgpu::TextureViewDimension::D2,
333                            multisampled: false,
334                        },
335                        count: None,
336                    },
337                    // Params
338                    wgpu::BindGroupLayoutEntry {
339                        binding: 3,
340                        visibility: wgpu::ShaderStages::COMPUTE,
341                        ty: wgpu::BindingType::Buffer {
342                            ty: wgpu::BufferBindingType::Uniform,
343                            has_dynamic_offset: false,
344                            min_binding_size: None,
345                        },
346                        count: None,
347                    },
348                    // Keypoint counter
349                    wgpu::BindGroupLayoutEntry {
350                        binding: 4,
351                        visibility: wgpu::ShaderStages::COMPUTE,
352                        ty: wgpu::BindingType::Buffer {
353                            ty: wgpu::BufferBindingType::Storage { read_only: false },
354                            has_dynamic_offset: false,
355                            min_binding_size: None,
356                        },
357                        count: None,
358                    },
359                    // Keypoints output
360                    wgpu::BindGroupLayoutEntry {
361                        binding: 5,
362                        visibility: wgpu::ShaderStages::COMPUTE,
363                        ty: wgpu::BindingType::Buffer {
364                            ty: wgpu::BufferBindingType::Storage { read_only: false },
365                            has_dynamic_offset: false,
366                            min_binding_size: None,
367                        },
368                        count: None,
369                    },
370                ],
371            });
372
373        // Load shaders
374        let blur_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
375            label: Some("Blur Shader V2"),
376            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_blur.wgsl").into()),
377        });
378
379        let dog_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
380            label: Some("DoG Shader V2"),
381            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_dog.wgsl").into()),
382        });
383
384        let downsample_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
385            label: Some("Downsample Shader V2"),
386            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_downsample.wgsl").into()),
387        });
388
389        let extrema_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
390            label: Some("Extrema Shader V2"),
391            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_extrema.wgsl").into()),
392        });
393
394        let orientation_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
395            label: Some("Orientation Shader V2"),
396            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_orientation.wgsl").into()),
397        });
398
399        let descriptor_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
400            label: Some("Descriptor Shader V2"),
401            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/gpu_descriptor.wgsl").into()),
402        });
403
404        let prepare_indirect_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
405            label: Some("Prepare Indirect Shader"),
406            source: wgpu::ShaderSource::Wgsl(
407                include_str!("shaders/gpu_prepare_indirect.wgsl").into(),
408            ),
409        });
410
411        // Prepare indirect bind group layout
412        let prepare_indirect_bind_group_layout =
413            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
414                label: Some("Prepare Indirect BGL"),
415                entries: &[
416                    // Counter (read via atomicLoad, needs read_write for atomic)
417                    wgpu::BindGroupLayoutEntry {
418                        binding: 0,
419                        visibility: wgpu::ShaderStages::COMPUTE,
420                        ty: wgpu::BindingType::Buffer {
421                            ty: wgpu::BufferBindingType::Storage { read_only: false },
422                            has_dynamic_offset: false,
423                            min_binding_size: None,
424                        },
425                        count: None,
426                    },
427                    // Indirect dispatch buffer (write)
428                    wgpu::BindGroupLayoutEntry {
429                        binding: 1,
430                        visibility: wgpu::ShaderStages::COMPUTE,
431                        ty: wgpu::BindingType::Buffer {
432                            ty: wgpu::BufferBindingType::Storage { read_only: false },
433                            has_dynamic_offset: false,
434                            min_binding_size: None,
435                        },
436                        count: None,
437                    },
438                ],
439            });
440
441        // Create pipelines
442        let blur_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
443            label: Some("Blur Layout"),
444            bind_group_layouts: &[&blur_bind_group_layout],
445            push_constant_ranges: &[],
446        });
447
448        let blur_h_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
449            label: Some("Blur H Pipeline V2"),
450            layout: Some(&blur_layout),
451            module: &blur_shader,
452            entry_point: Some("blur_horizontal"),
453            compilation_options: Default::default(),
454            cache: None,
455        });
456
457        let blur_v_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
458            label: Some("Blur V Pipeline V2"),
459            layout: Some(&blur_layout),
460            module: &blur_shader,
461            entry_point: Some("blur_vertical"),
462            compilation_options: Default::default(),
463            cache: None,
464        });
465
466        let dog_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
467            label: Some("DoG Layout"),
468            bind_group_layouts: &[&dog_bind_group_layout],
469            push_constant_ranges: &[],
470        });
471
472        let dog_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
473            label: Some("DoG Pipeline V2"),
474            layout: Some(&dog_layout),
475            module: &dog_shader,
476            entry_point: Some("compute_dog"),
477            compilation_options: Default::default(),
478            cache: None,
479        });
480
481        let downsample_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
482            label: Some("Downsample Layout"),
483            bind_group_layouts: &[&downsample_bind_group_layout],
484            push_constant_ranges: &[],
485        });
486
487        let downsample_pipeline =
488            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
489                label: Some("Downsample Pipeline V2"),
490                layout: Some(&downsample_layout),
491                module: &downsample_shader,
492                entry_point: Some("downsample_2x"),
493                compilation_options: Default::default(),
494                cache: None,
495            });
496
497        let extrema_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
498            label: Some("Extrema Layout"),
499            bind_group_layouts: &[&extrema_bind_group_layout],
500            push_constant_ranges: &[],
501        });
502
503        let extrema_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
504            label: Some("Extrema Pipeline V2"),
505            layout: Some(&extrema_layout),
506            module: &extrema_shader,
507            entry_point: Some("detect_extrema"),
508            compilation_options: Default::default(),
509            cache: None,
510        });
511
512        // Orientation and descriptor use auto-layout for now
513        let orientation_pipeline =
514            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
515                label: Some("Orientation Pipeline V2"),
516                layout: None,
517                module: &orientation_shader,
518                entry_point: Some("compute_orientation"),
519                compilation_options: Default::default(),
520                cache: None,
521            });
522
523        let descriptor_pipeline =
524            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
525                label: Some("Descriptor Pipeline V2"),
526                layout: None,
527                module: &descriptor_shader,
528                entry_point: Some("compute_descriptor"),
529                compilation_options: Default::default(),
530                cache: None,
531            });
532
533        // Prepare indirect dispatch pipelines
534        let prepare_indirect_layout =
535            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
536                label: Some("Prepare Indirect Layout"),
537                bind_group_layouts: &[&prepare_indirect_bind_group_layout],
538                push_constant_ranges: &[],
539            });
540
541        let prepare_orient_indirect_pipeline =
542            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
543                label: Some("Prepare Orient Indirect Pipeline"),
544                layout: Some(&prepare_indirect_layout),
545                module: &prepare_indirect_shader,
546                entry_point: Some("prepare_orientation_indirect"),
547                compilation_options: Default::default(),
548                cache: None,
549            });
550
551        let prepare_desc_indirect_pipeline =
552            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
553                label: Some("Prepare Desc Indirect Pipeline"),
554                layout: Some(&prepare_indirect_layout),
555                module: &prepare_indirect_shader,
556                entry_point: Some("prepare_descriptor_indirect"),
557                compilation_options: Default::default(),
558                cache: None,
559            });
560
561        Ok(Self {
562            device,
563            queue,
564            blur_h_pipeline,
565            blur_v_pipeline,
566            dog_pipeline,
567            downsample_pipeline,
568            extrema_pipeline,
569            orientation_pipeline,
570            descriptor_pipeline,
571            prepare_orient_indirect_pipeline,
572            prepare_desc_indirect_pipeline,
573            blur_bind_group_layout,
574            dog_bind_group_layout,
575            downsample_bind_group_layout,
576            extrema_bind_group_layout,
577            prepare_indirect_bind_group_layout,
578            linear_sampler,
579            nearest_sampler,
580            resources: None,
581            config,
582        })
583    }
584
585    fn ensure_resources(&mut self, width: u32, height: u32) {
586        if let Some(ref res) = self.resources {
587            if res.width == width && res.height == height {
588                return;
589            }
590        }
591
592        let scales_per_octave = self.config.scales_per_octave + 3; // Total Gaussian images per octave
593        let dog_scales = scales_per_octave - 1;
594
595        let mut gaussian_textures = Vec::new();
596        let mut gaussian_views = Vec::new();
597        let mut dog_textures = Vec::new();
598        let mut dog_views = Vec::new();
599        let mut temp_textures = Vec::new();
600        let mut temp_views = Vec::new();
601
602        let mut w = width;
603        let mut h = height;
604
605        for octave in 0..self.config.octaves {
606            if w < 8 || h < 8 {
607                break;
608            }
609
610            let mut octave_gaussian = Vec::new();
611            let mut octave_gaussian_views = Vec::new();
612            let mut octave_dog = Vec::new();
613            let mut octave_dog_views = Vec::new();
614
615            // Gaussian textures for this octave
616            for s in 0..scales_per_octave {
617                let tex = self.device.create_texture(&wgpu::TextureDescriptor {
618                    label: Some(&format!("Gaussian O{}S{}", octave, s)),
619                    size: wgpu::Extent3d {
620                        width: w,
621                        height: h,
622                        depth_or_array_layers: 1,
623                    },
624                    mip_level_count: 1,
625                    sample_count: 1,
626                    dimension: wgpu::TextureDimension::D2,
627                    format: wgpu::TextureFormat::R32Float,
628                    usage: wgpu::TextureUsages::TEXTURE_BINDING
629                        | wgpu::TextureUsages::STORAGE_BINDING
630                        | wgpu::TextureUsages::COPY_DST,
631                    view_formats: &[],
632                });
633                let view = tex.create_view(&Default::default());
634                octave_gaussian.push(tex);
635                octave_gaussian_views.push(view);
636            }
637
638            // DoG textures for this octave
639            for d in 0..dog_scales {
640                let tex = self.device.create_texture(&wgpu::TextureDescriptor {
641                    label: Some(&format!("DoG O{}D{}", octave, d)),
642                    size: wgpu::Extent3d {
643                        width: w,
644                        height: h,
645                        depth_or_array_layers: 1,
646                    },
647                    mip_level_count: 1,
648                    sample_count: 1,
649                    dimension: wgpu::TextureDimension::D2,
650                    format: wgpu::TextureFormat::R32Float,
651                    usage: wgpu::TextureUsages::TEXTURE_BINDING
652                        | wgpu::TextureUsages::STORAGE_BINDING,
653                    view_formats: &[],
654                });
655                let view = tex.create_view(&Default::default());
656                octave_dog.push(tex);
657                octave_dog_views.push(view);
658            }
659
660            // Temp texture for blur intermediate
661            let temp = self.device.create_texture(&wgpu::TextureDescriptor {
662                label: Some(&format!("Temp O{}", octave)),
663                size: wgpu::Extent3d {
664                    width: w,
665                    height: h,
666                    depth_or_array_layers: 1,
667                },
668                mip_level_count: 1,
669                sample_count: 1,
670                dimension: wgpu::TextureDimension::D2,
671                format: wgpu::TextureFormat::R32Float,
672                usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::STORAGE_BINDING,
673                view_formats: &[],
674            });
675            let temp_view = temp.create_view(&Default::default());
676            temp_textures.push(temp);
677            temp_views.push(temp_view);
678
679            gaussian_textures.push(octave_gaussian);
680            gaussian_views.push(octave_gaussian_views);
681            dog_textures.push(octave_dog);
682            dog_views.push(octave_dog_views);
683
684            w /= 2;
685            h /= 2;
686        }
687
688        // Keypoint buffers
689        let max_keypoints = self.config.max_keypoints as u64;
690
691        let keypoint_counter = self.device.create_buffer(&wgpu::BufferDescriptor {
692            label: Some("Keypoint Counter"),
693            size: 4,
694            usage: wgpu::BufferUsages::STORAGE
695                | wgpu::BufferUsages::COPY_DST
696                | wgpu::BufferUsages::COPY_SRC,
697            mapped_at_creation: false,
698        });
699
700        let keypoints = self.device.create_buffer(&wgpu::BufferDescriptor {
701            label: Some("Keypoints"),
702            size: max_keypoints * 32, // x, y, sigma, response, octave, scale, pad, pad
703            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
704            mapped_at_creation: false,
705        });
706
707        let oriented_keypoint_counter = self.device.create_buffer(&wgpu::BufferDescriptor {
708            label: Some("Oriented Keypoint Counter"),
709            size: 4,
710            usage: wgpu::BufferUsages::STORAGE
711                | wgpu::BufferUsages::COPY_DST
712                | wgpu::BufferUsages::COPY_SRC,
713            mapped_at_creation: false,
714        });
715
716        let oriented_keypoints = self.device.create_buffer(&wgpu::BufferDescriptor {
717            label: Some("Oriented Keypoints"),
718            size: max_keypoints * 2 * 16, // x, y, sigma, angle
719            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
720            mapped_at_creation: false,
721        });
722
723        let descriptors = self.device.create_buffer(&wgpu::BufferDescriptor {
724            label: Some("Descriptors"),
725            size: max_keypoints * 2 * 128, // 128 bytes per descriptor
726            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
727            mapped_at_creation: false,
728        });
729
730        // Indirect dispatch buffers (12 bytes each: x, y, z as u32)
731        let orientation_indirect = self.device.create_buffer(&wgpu::BufferDescriptor {
732            label: Some("Orientation Indirect"),
733            size: 12,
734            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::INDIRECT,
735            mapped_at_creation: false,
736        });
737
738        let descriptor_indirect = self.device.create_buffer(&wgpu::BufferDescriptor {
739            label: Some("Descriptor Indirect"),
740            size: 12,
741            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::INDIRECT,
742            mapped_at_creation: false,
743        });
744
745        // Readback buffers
746        let readback_counters = self.device.create_buffer(&wgpu::BufferDescriptor {
747            label: Some("Readback Counters"),
748            size: 8,
749            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
750            mapped_at_creation: false,
751        });
752
753        let readback_keypoints = self.device.create_buffer(&wgpu::BufferDescriptor {
754            label: Some("Readback Keypoints"),
755            size: max_keypoints * 2 * 16,
756            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
757            mapped_at_creation: false,
758        });
759
760        let readback_descriptors = self.device.create_buffer(&wgpu::BufferDescriptor {
761            label: Some("Readback Descriptors"),
762            size: max_keypoints * 2 * 128,
763            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
764            mapped_at_creation: false,
765        });
766
767        self.resources = Some(GpuResources {
768            width,
769            height,
770            gaussian_textures,
771            gaussian_views,
772            dog_textures,
773            dog_views,
774            temp_textures,
775            temp_views,
776            keypoint_counter,
777            keypoints,
778            oriented_keypoint_counter,
779            oriented_keypoints,
780            descriptors,
781            orientation_indirect,
782            descriptor_indirect,
783            readback_counters,
784            readback_keypoints,
785            readback_descriptors,
786        });
787    }
788
789    pub async fn detect(
790        &mut self,
791        image: &[u8],
792        width: u32,
793        height: u32,
794    ) -> Result<(Vec<KeyPoint>, Vec<[u8; 128]>), Box<dyn std::error::Error>> {
795        let profile = std::env::var("SIFT_PROFILE").is_ok();
796        let total_start = web_time::Instant::now();
797
798        // Ensure resources
799        let t0 = web_time::Instant::now();
800        self.ensure_resources(width, height);
801        if profile {
802            eprintln!("  [GPU V2] Resource setup: {:?}", t0.elapsed());
803        }
804
805        let _res = self.resources.as_ref().unwrap();
806
807        // Upload image to first Gaussian texture
808        let t1 = web_time::Instant::now();
809        self.upload_image(image, width, height)?;
810        if profile {
811            eprintln!("  [GPU V2] Upload: {:?}", t1.elapsed());
812        }
813
814        // Build Gaussian pyramid on GPU
815        let t2 = web_time::Instant::now();
816        self.build_gaussian_pyramid(width, height)?;
817        if profile {
818            eprintln!("  [GPU V2] Gaussian pyramid: {:?}", t2.elapsed());
819        }
820
821        // Compute DoG
822        let t3 = web_time::Instant::now();
823        self.compute_dog(width, height)?;
824        if profile {
825            eprintln!("  [GPU V2] DoG: {:?}", t3.elapsed());
826        }
827
828        // Detect extrema
829        let t4 = web_time::Instant::now();
830        self.detect_extrema(width, height)?;
831        if profile {
832            eprintln!("  [GPU V2] Extrema: {:?}", t4.elapsed());
833        }
834
835        // Orientation assignment
836        let t5 = web_time::Instant::now();
837        self.compute_orientation(width, height)?;
838        if profile {
839            eprintln!("  [GPU V2] Orientation: {:?}", t5.elapsed());
840        }
841
842        // Descriptors
843        let t6 = web_time::Instant::now();
844        self.compute_descriptors(width, height)?;
845        if profile {
846            eprintln!("  [GPU V2] Descriptors: {:?}", t6.elapsed());
847        }
848
849        // Readback
850        let t7 = web_time::Instant::now();
851        let result = self.readback_results().await?;
852        if profile {
853            eprintln!("  [GPU V2] Readback: {:?}", t7.elapsed());
854            eprintln!("  [GPU V2] Total: {:?}", total_start.elapsed());
855        }
856
857        Ok(result)
858    }
859
860    fn upload_image(
861        &self,
862        image: &[u8],
863        width: u32,
864        height: u32,
865    ) -> Result<(), Box<dyn std::error::Error>> {
866        let res = self.resources.as_ref().unwrap();
867
868        // Convert u8 to f32
869        let image_f32: Vec<f32> = image.iter().map(|&p| p as f32 / 255.0).collect();
870        let bytes: Vec<u8> = image_f32.iter().flat_map(|f| f.to_le_bytes()).collect();
871
872        self.queue.write_texture(
873            wgpu::TexelCopyTextureInfo {
874                texture: &res.gaussian_textures[0][0],
875                mip_level: 0,
876                origin: wgpu::Origin3d::ZERO,
877                aspect: wgpu::TextureAspect::All,
878            },
879            &bytes,
880            wgpu::TexelCopyBufferLayout {
881                offset: 0,
882                bytes_per_row: Some(width * 4),
883                rows_per_image: Some(height),
884            },
885            wgpu::Extent3d {
886                width,
887                height,
888                depth_or_array_layers: 1,
889            },
890        );
891
892        Ok(())
893    }
894
895    /// Helper: create a uniform buffer with initial data
896    fn create_uniform_buffer(&self, data: &[u8], label: &str) -> wgpu::Buffer {
897        use wgpu::util::DeviceExt;
898        self.device
899            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
900                label: Some(label),
901                contents: data,
902                usage: wgpu::BufferUsages::UNIFORM,
903            })
904    }
905
906    /// Helper: create a storage buffer with initial data
907    fn create_storage_buffer(&self, data: &[u8], label: &str) -> wgpu::Buffer {
908        use wgpu::util::DeviceExt;
909        self.device
910            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
911                label: Some(label),
912                contents: data,
913                usage: wgpu::BufferUsages::STORAGE,
914            })
915    }
916
917    fn build_gaussian_pyramid(
918        &self,
919        width: u32,
920        height: u32,
921    ) -> Result<(), Box<dyn std::error::Error>> {
922        let res = self.resources.as_ref().unwrap();
923        let scales_per_octave = self.config.scales_per_octave + 3;
924
925        // k is the multiplier between scales
926        let k = 2.0_f32.powf(1.0 / self.config.scales_per_octave as f32);
927
928        let mut encoder = self
929            .device
930            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
931                label: Some("Gaussian Pyramid Encoder"),
932            });
933
934        let mut w = width;
935        let mut h = height;
936
937        for octave in 0..res.gaussian_textures.len() {
938            if w < 8 || h < 8 {
939                break;
940            }
941
942            for s in 0..scales_per_octave as usize {
943                // Compute sigma for this scale
944                let sigma = if s == 0 && octave == 0 {
945                    // First scale of first octave: blur from assumed blur to base_sigma
946                    let assumed_blur = 0.5f32;
947                    if self.config.base_sigma > assumed_blur {
948                        (self.config.base_sigma * self.config.base_sigma
949                            - assumed_blur * assumed_blur)
950                            .sqrt()
951                    } else {
952                        0.0
953                    }
954                } else if s == 0 {
955                    // First scale of other octaves: already blurred from downsample
956                    0.0
957                } else {
958                    // Incremental blur
959                    let sigma_prev = self.config.base_sigma * k.powi((s - 1) as i32);
960                    let sigma_curr = self.config.base_sigma * k.powi(s as i32);
961                    (sigma_curr * sigma_curr - sigma_prev * sigma_prev).sqrt()
962                };
963
964                if sigma < 0.1 {
965                    // Skip blur, just copy
966                    if s > 0 {
967                        encoder.copy_texture_to_texture(
968                            wgpu::TexelCopyTextureInfo {
969                                texture: &res.gaussian_textures[octave][s - 1],
970                                mip_level: 0,
971                                origin: wgpu::Origin3d::ZERO,
972                                aspect: wgpu::TextureAspect::All,
973                            },
974                            wgpu::TexelCopyTextureInfo {
975                                texture: &res.gaussian_textures[octave][s],
976                                mip_level: 0,
977                                origin: wgpu::Origin3d::ZERO,
978                                aspect: wgpu::TextureAspect::All,
979                            },
980                            wgpu::Extent3d {
981                                width: w,
982                                height: h,
983                                depth_or_array_layers: 1,
984                            },
985                        );
986                    }
987                    continue;
988                }
989
990                // Build kernel
991                let radius = (sigma * 2.5).ceil() as i32;
992                let kernel_size = (2 * radius + 1) as usize;
993                let mut kernel = vec![0.0f32; kernel_size];
994                let mut sum = 0.0f32;
995                let two_sigma_sq = 2.0 * sigma * sigma;
996
997                for (i, kv) in kernel.iter_mut().enumerate() {
998                    let x = (i as i32 - radius) as f32;
999                    *kv = (-x * x / two_sigma_sq).exp();
1000                    sum += *kv;
1001                }
1002                for kv in kernel.iter_mut() {
1003                    *kv /= sum;
1004                }
1005
1006                // Create kernel buffer for this pass
1007                let kernel_bytes: Vec<u8> = kernel.iter().flat_map(|f| f.to_le_bytes()).collect();
1008                let kernel_buffer = self.create_storage_buffer(&kernel_bytes, "Kernel Buffer");
1009
1010                // Source texture
1011                let src_view = if s == 0 {
1012                    if octave == 0 {
1013                        &res.gaussian_views[0][0]
1014                    } else {
1015                        &res.gaussian_views[octave][0]
1016                    }
1017                } else {
1018                    &res.gaussian_views[octave][s - 1]
1019                };
1020
1021                // Horizontal pass - create params buffer for this specific pass
1022                let params_h = [w, h, radius as u32, 0u32]; // direction 0 = horizontal
1023                let params_h_bytes: Vec<u8> =
1024                    params_h.iter().flat_map(|v| v.to_le_bytes()).collect();
1025                let params_h_buffer = self.create_uniform_buffer(&params_h_bytes, "Blur H Params");
1026
1027                let blur_h_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1028                    label: Some("Blur H BG"),
1029                    layout: &self.blur_bind_group_layout,
1030                    entries: &[
1031                        wgpu::BindGroupEntry {
1032                            binding: 0,
1033                            resource: wgpu::BindingResource::TextureView(src_view),
1034                        },
1035                        wgpu::BindGroupEntry {
1036                            binding: 1,
1037                            resource: wgpu::BindingResource::Sampler(&self.linear_sampler),
1038                        },
1039                        wgpu::BindGroupEntry {
1040                            binding: 2,
1041                            resource: wgpu::BindingResource::TextureView(&res.temp_views[octave]),
1042                        },
1043                        wgpu::BindGroupEntry {
1044                            binding: 3,
1045                            resource: params_h_buffer.as_entire_binding(),
1046                        },
1047                        wgpu::BindGroupEntry {
1048                            binding: 4,
1049                            resource: kernel_buffer.as_entire_binding(),
1050                        },
1051                    ],
1052                });
1053
1054                {
1055                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1056                        label: Some("Blur H Pass"),
1057                        timestamp_writes: None,
1058                    });
1059                    pass.set_pipeline(&self.blur_h_pipeline);
1060                    pass.set_bind_group(0, &blur_h_bg, &[]);
1061                    pass.dispatch_workgroups((w + 15) / 16, (h + 15) / 16, 1);
1062                }
1063
1064                // Vertical pass - create params buffer for this specific pass
1065                let params_v = [w, h, radius as u32, 1u32]; // direction 1 = vertical
1066                let params_v_bytes: Vec<u8> =
1067                    params_v.iter().flat_map(|v| v.to_le_bytes()).collect();
1068                let params_v_buffer = self.create_uniform_buffer(&params_v_bytes, "Blur V Params");
1069
1070                let blur_v_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1071                    label: Some("Blur V BG"),
1072                    layout: &self.blur_bind_group_layout,
1073                    entries: &[
1074                        wgpu::BindGroupEntry {
1075                            binding: 0,
1076                            resource: wgpu::BindingResource::TextureView(&res.temp_views[octave]),
1077                        },
1078                        wgpu::BindGroupEntry {
1079                            binding: 1,
1080                            resource: wgpu::BindingResource::Sampler(&self.linear_sampler),
1081                        },
1082                        wgpu::BindGroupEntry {
1083                            binding: 2,
1084                            resource: wgpu::BindingResource::TextureView(
1085                                &res.gaussian_views[octave][s],
1086                            ),
1087                        },
1088                        wgpu::BindGroupEntry {
1089                            binding: 3,
1090                            resource: params_v_buffer.as_entire_binding(),
1091                        },
1092                        wgpu::BindGroupEntry {
1093                            binding: 4,
1094                            resource: kernel_buffer.as_entire_binding(),
1095                        },
1096                    ],
1097                });
1098
1099                {
1100                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1101                        label: Some("Blur V Pass"),
1102                        timestamp_writes: None,
1103                    });
1104                    pass.set_pipeline(&self.blur_v_pipeline);
1105                    pass.set_bind_group(0, &blur_v_bg, &[]);
1106                    pass.dispatch_workgroups((w + 15) / 16, (h + 15) / 16, 1);
1107                }
1108            }
1109
1110            // Downsample for next octave
1111            if octave + 1 < res.gaussian_textures.len() {
1112                let next_w = w / 2;
1113                let next_h = h / 2;
1114
1115                // Source from scale (scales_per_octave - 3) which has 2x base blur
1116                let src_scale = (scales_per_octave as usize).saturating_sub(3);
1117
1118                let params = [w, h, next_w, next_h];
1119                let params_bytes: Vec<u8> = params.iter().flat_map(|v| v.to_le_bytes()).collect();
1120                let params_buffer = self.create_uniform_buffer(&params_bytes, "Downsample Params");
1121
1122                let ds_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1123                    label: Some("Downsample BG"),
1124                    layout: &self.downsample_bind_group_layout,
1125                    entries: &[
1126                        wgpu::BindGroupEntry {
1127                            binding: 0,
1128                            resource: wgpu::BindingResource::TextureView(
1129                                &res.gaussian_views[octave][src_scale],
1130                            ),
1131                        },
1132                        wgpu::BindGroupEntry {
1133                            binding: 1,
1134                            resource: wgpu::BindingResource::Sampler(&self.linear_sampler),
1135                        },
1136                        wgpu::BindGroupEntry {
1137                            binding: 2,
1138                            resource: wgpu::BindingResource::TextureView(
1139                                &res.gaussian_views[octave + 1][0],
1140                            ),
1141                        },
1142                        wgpu::BindGroupEntry {
1143                            binding: 3,
1144                            resource: params_buffer.as_entire_binding(),
1145                        },
1146                    ],
1147                });
1148
1149                {
1150                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1151                        label: Some("Downsample Pass"),
1152                        timestamp_writes: None,
1153                    });
1154                    pass.set_pipeline(&self.downsample_pipeline);
1155                    pass.set_bind_group(0, &ds_bg, &[]);
1156                    pass.dispatch_workgroups((next_w + 15) / 16, (next_h + 15) / 16, 1);
1157                }
1158            }
1159
1160            w /= 2;
1161            h /= 2;
1162        }
1163
1164        self.queue.submit(Some(encoder.finish()));
1165        // No poll here - let GPU work continue
1166        Ok(())
1167    }
1168
1169    fn compute_dog(&self, width: u32, height: u32) -> Result<(), Box<dyn std::error::Error>> {
1170        let res = self.resources.as_ref().unwrap();
1171
1172        let mut encoder = self
1173            .device
1174            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1175                label: Some("DoG Encoder"),
1176            });
1177
1178        let mut w = width;
1179        let mut h = height;
1180
1181        for octave in 0..res.dog_textures.len() {
1182            if w < 8 || h < 8 {
1183                break;
1184            }
1185
1186            for d in 0..res.dog_views[octave].len() {
1187                let params = [w, h, 0u32, 0u32];
1188                let params_bytes: Vec<u8> = params.iter().flat_map(|v| v.to_le_bytes()).collect();
1189                let params_buffer = self.create_uniform_buffer(&params_bytes, "DoG Params");
1190
1191                let dog_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1192                    label: Some("DoG BG"),
1193                    layout: &self.dog_bind_group_layout,
1194                    entries: &[
1195                        wgpu::BindGroupEntry {
1196                            binding: 0,
1197                            resource: wgpu::BindingResource::TextureView(
1198                                &res.gaussian_views[octave][d + 1],
1199                            ),
1200                        },
1201                        wgpu::BindGroupEntry {
1202                            binding: 1,
1203                            resource: wgpu::BindingResource::TextureView(
1204                                &res.gaussian_views[octave][d],
1205                            ),
1206                        },
1207                        wgpu::BindGroupEntry {
1208                            binding: 2,
1209                            resource: wgpu::BindingResource::TextureView(&res.dog_views[octave][d]),
1210                        },
1211                        wgpu::BindGroupEntry {
1212                            binding: 3,
1213                            resource: params_buffer.as_entire_binding(),
1214                        },
1215                    ],
1216                });
1217
1218                {
1219                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1220                        label: Some("DoG Pass"),
1221                        timestamp_writes: None,
1222                    });
1223                    pass.set_pipeline(&self.dog_pipeline);
1224                    pass.set_bind_group(0, &dog_bg, &[]);
1225                    pass.dispatch_workgroups((w + 15) / 16, (h + 15) / 16, 1);
1226                }
1227            }
1228
1229            w /= 2;
1230            h /= 2;
1231        }
1232
1233        self.queue.submit(Some(encoder.finish()));
1234        // No poll here - let GPU work continue
1235        Ok(())
1236    }
1237
1238    fn detect_extrema(&self, width: u32, height: u32) -> Result<(), Box<dyn std::error::Error>> {
1239        let res = self.resources.as_ref().unwrap();
1240
1241        // Clear counter
1242        self.queue.write_buffer(&res.keypoint_counter, 0, &[0u8; 4]);
1243
1244        let mut encoder = self
1245            .device
1246            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1247                label: Some("Extrema Encoder"),
1248            });
1249
1250        let mut w = width;
1251        let mut h = height;
1252        let k = 2.0_f32.powf(1.0 / self.config.scales_per_octave as f32);
1253
1254        for octave in 0..res.dog_textures.len() {
1255            if w < 8 || h < 8 {
1256                break;
1257            }
1258
1259            // Need 3 adjacent DoG scales for extrema detection
1260            for d in 1..(res.dog_views[octave].len() - 1) {
1261                let sigma = self.config.base_sigma * k.powi(d as i32) * (1 << octave) as f32;
1262
1263                #[repr(C)]
1264                struct ExtremaParams {
1265                    width: u32,
1266                    height: u32,
1267                    octave: u32,
1268                    scale: u32,
1269                    contrast_threshold: f32,
1270                    edge_threshold: f32,
1271                    sigma: f32,
1272                    _pad: u32,
1273                }
1274
1275                let params = ExtremaParams {
1276                    width: w,
1277                    height: h,
1278                    octave: octave as u32,
1279                    scale: d as u32,
1280                    // SIFT paper: threshold = 0.04 / num_intervals
1281                    contrast_threshold: self.config.contrast_threshold
1282                        / self.config.scales_per_octave as f32,
1283                    edge_threshold: self.config.edge_threshold,
1284                    sigma,
1285                    _pad: 0,
1286                };
1287
1288                let params_bytes = unsafe {
1289                    std::slice::from_raw_parts(
1290                        &params as *const _ as *const u8,
1291                        std::mem::size_of::<ExtremaParams>(),
1292                    )
1293                };
1294                let params_buffer = self.create_uniform_buffer(params_bytes, "Extrema Params");
1295
1296                let extrema_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1297                    label: Some("Extrema BG"),
1298                    layout: &self.extrema_bind_group_layout,
1299                    entries: &[
1300                        wgpu::BindGroupEntry {
1301                            binding: 0,
1302                            resource: wgpu::BindingResource::TextureView(
1303                                &res.dog_views[octave][d - 1],
1304                            ),
1305                        },
1306                        wgpu::BindGroupEntry {
1307                            binding: 1,
1308                            resource: wgpu::BindingResource::TextureView(&res.dog_views[octave][d]),
1309                        },
1310                        wgpu::BindGroupEntry {
1311                            binding: 2,
1312                            resource: wgpu::BindingResource::TextureView(
1313                                &res.dog_views[octave][d + 1],
1314                            ),
1315                        },
1316                        wgpu::BindGroupEntry {
1317                            binding: 3,
1318                            resource: params_buffer.as_entire_binding(),
1319                        },
1320                        wgpu::BindGroupEntry {
1321                            binding: 4,
1322                            resource: res.keypoint_counter.as_entire_binding(),
1323                        },
1324                        wgpu::BindGroupEntry {
1325                            binding: 5,
1326                            resource: res.keypoints.as_entire_binding(),
1327                        },
1328                    ],
1329                });
1330
1331                {
1332                    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1333                        label: Some("Extrema Pass"),
1334                        timestamp_writes: None,
1335                    });
1336                    pass.set_pipeline(&self.extrema_pipeline);
1337                    pass.set_bind_group(0, &extrema_bg, &[]);
1338                    pass.dispatch_workgroups((w + 15) / 16, (h + 15) / 16, 1);
1339                }
1340            }
1341
1342            w /= 2;
1343            h /= 2;
1344        }
1345
1346        self.queue.submit(Some(encoder.finish()));
1347        // No poll - next stage uses indirect dispatch
1348
1349        Ok(())
1350    }
1351
1352    fn compute_orientation(
1353        &self,
1354        width: u32,
1355        height: u32,
1356    ) -> Result<(), Box<dyn std::error::Error>> {
1357        let res = self.resources.as_ref().unwrap();
1358
1359        // Clear output counter
1360        self.queue
1361            .write_buffer(&res.oriented_keypoint_counter, 0, &[0u8; 4]);
1362
1363        // Create bind group for prepare indirect dispatch
1364        let prepare_indirect_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1365            label: Some("Prepare Orient Indirect BG"),
1366            layout: &self.prepare_indirect_bind_group_layout,
1367            entries: &[
1368                wgpu::BindGroupEntry {
1369                    binding: 0,
1370                    resource: res.keypoint_counter.as_entire_binding(),
1371                },
1372                wgpu::BindGroupEntry {
1373                    binding: 1,
1374                    resource: res.orientation_indirect.as_entire_binding(),
1375                },
1376            ],
1377        });
1378
1379        // Max keypoints for params (shader will check bounds)
1380        #[repr(C)]
1381        struct OrientParams {
1382            width: u32,
1383            height: u32,
1384            octave: u32,
1385            num_keypoints: u32,
1386        }
1387
1388        let params = OrientParams {
1389            width,
1390            height,
1391            octave: 0,
1392            num_keypoints: 32768, // Max, shader checks actual count
1393        };
1394
1395        let params_bytes = unsafe {
1396            std::slice::from_raw_parts(
1397                &params as *const _ as *const u8,
1398                std::mem::size_of::<OrientParams>(),
1399            )
1400        };
1401        let params_buffer = self.create_uniform_buffer(params_bytes, "Orientation Params");
1402
1403        // Use the first scale of first octave for gradient computation
1404        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1405            label: Some("Orientation BG"),
1406            layout: &self.orientation_pipeline.get_bind_group_layout(0),
1407            entries: &[
1408                wgpu::BindGroupEntry {
1409                    binding: 0,
1410                    resource: wgpu::BindingResource::TextureView(&res.gaussian_views[0][0]),
1411                },
1412                wgpu::BindGroupEntry {
1413                    binding: 1,
1414                    resource: params_buffer.as_entire_binding(),
1415                },
1416                wgpu::BindGroupEntry {
1417                    binding: 2,
1418                    resource: res.keypoints.as_entire_binding(),
1419                },
1420                wgpu::BindGroupEntry {
1421                    binding: 3,
1422                    resource: res.oriented_keypoint_counter.as_entire_binding(),
1423                },
1424                wgpu::BindGroupEntry {
1425                    binding: 4,
1426                    resource: res.oriented_keypoints.as_entire_binding(),
1427                },
1428            ],
1429        });
1430
1431        let mut encoder = self
1432            .device
1433            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1434                label: Some("Orientation Encoder"),
1435            });
1436
1437        // First: prepare indirect dispatch buffer from keypoint counter
1438        {
1439            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1440                label: Some("Prepare Orient Indirect Pass"),
1441                timestamp_writes: None,
1442            });
1443            pass.set_pipeline(&self.prepare_orient_indirect_pipeline);
1444            pass.set_bind_group(0, &prepare_indirect_bg, &[]);
1445            pass.dispatch_workgroups(1, 1, 1);
1446        }
1447
1448        // Then: run orientation using indirect dispatch
1449        {
1450            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1451                label: Some("Orientation Pass"),
1452                timestamp_writes: None,
1453            });
1454            pass.set_pipeline(&self.orientation_pipeline);
1455            pass.set_bind_group(0, &bind_group, &[]);
1456            pass.dispatch_workgroups_indirect(&res.orientation_indirect, 0);
1457        }
1458
1459        self.queue.submit(Some(encoder.finish()));
1460        // No poll - descriptor stage uses indirect dispatch too
1461
1462        Ok(())
1463    }
1464
1465    fn compute_descriptors(
1466        &self,
1467        width: u32,
1468        height: u32,
1469    ) -> Result<(), Box<dyn std::error::Error>> {
1470        let res = self.resources.as_ref().unwrap();
1471
1472        // Create bind group for prepare indirect dispatch
1473        let prepare_indirect_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1474            label: Some("Prepare Desc Indirect BG"),
1475            layout: &self.prepare_indirect_bind_group_layout,
1476            entries: &[
1477                wgpu::BindGroupEntry {
1478                    binding: 0,
1479                    resource: res.oriented_keypoint_counter.as_entire_binding(),
1480                },
1481                wgpu::BindGroupEntry {
1482                    binding: 1,
1483                    resource: res.descriptor_indirect.as_entire_binding(),
1484                },
1485            ],
1486        });
1487
1488        #[repr(C)]
1489        struct DescParams {
1490            width: u32,
1491            height: u32,
1492            octave: u32,
1493            num_keypoints: u32,
1494        }
1495
1496        let params = DescParams {
1497            width,
1498            height,
1499            octave: 0,
1500            num_keypoints: 65536, // Max, shader checks actual count
1501        };
1502
1503        let params_bytes = unsafe {
1504            std::slice::from_raw_parts(
1505                &params as *const _ as *const u8,
1506                std::mem::size_of::<DescParams>(),
1507            )
1508        };
1509        let params_buffer = self.create_uniform_buffer(params_bytes, "Descriptor Params");
1510
1511        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1512            label: Some("Descriptor BG"),
1513            layout: &self.descriptor_pipeline.get_bind_group_layout(0),
1514            entries: &[
1515                wgpu::BindGroupEntry {
1516                    binding: 0,
1517                    resource: wgpu::BindingResource::TextureView(&res.gaussian_views[0][0]),
1518                },
1519                wgpu::BindGroupEntry {
1520                    binding: 1,
1521                    resource: params_buffer.as_entire_binding(),
1522                },
1523                wgpu::BindGroupEntry {
1524                    binding: 2,
1525                    resource: res.oriented_keypoints.as_entire_binding(),
1526                },
1527                wgpu::BindGroupEntry {
1528                    binding: 3,
1529                    resource: res.descriptors.as_entire_binding(),
1530                },
1531            ],
1532        });
1533
1534        let mut encoder = self
1535            .device
1536            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1537                label: Some("Descriptor Encoder"),
1538            });
1539
1540        // First: prepare indirect dispatch buffer from oriented keypoint counter
1541        {
1542            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1543                label: Some("Prepare Desc Indirect Pass"),
1544                timestamp_writes: None,
1545            });
1546            pass.set_pipeline(&self.prepare_desc_indirect_pipeline);
1547            pass.set_bind_group(0, &prepare_indirect_bg, &[]);
1548            pass.dispatch_workgroups(1, 1, 1);
1549        }
1550
1551        // Then: run descriptors using indirect dispatch
1552        {
1553            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1554                label: Some("Descriptor Pass"),
1555                timestamp_writes: None,
1556            });
1557            pass.set_pipeline(&self.descriptor_pipeline);
1558            pass.set_bind_group(0, &bind_group, &[]);
1559            pass.dispatch_workgroups_indirect(&res.descriptor_indirect, 0);
1560        }
1561
1562        self.queue.submit(Some(encoder.finish()));
1563        // No poll - readback will sync
1564
1565        Ok(())
1566    }
1567
1568    async fn readback_results(
1569        &self,
1570    ) -> Result<(Vec<KeyPoint>, Vec<[u8; 128]>), Box<dyn std::error::Error>> {
1571        let res = self.resources.as_ref().unwrap();
1572
1573        // Copy counter and all data in ONE submit
1574        let max_keypoints = self.config.max_keypoints as u64;
1575
1576        let mut encoder = self
1577            .device
1578            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1579                label: Some("Readback Encoder"),
1580            });
1581
1582        // Copy counter
1583        encoder.copy_buffer_to_buffer(
1584            &res.oriented_keypoint_counter,
1585            0,
1586            &res.readback_counters,
1587            0,
1588            4,
1589        );
1590
1591        // Copy keypoints and descriptors
1592        encoder.copy_buffer_to_buffer(
1593            &res.oriented_keypoints,
1594            0,
1595            &res.readback_keypoints,
1596            0,
1597            max_keypoints * 16,
1598        );
1599        encoder.copy_buffer_to_buffer(
1600            &res.descriptors,
1601            0,
1602            &res.readback_descriptors,
1603            0,
1604            max_keypoints * 128,
1605        );
1606
1607        self.queue.submit(Some(encoder.finish()));
1608
1609        // Map ALL buffers in parallel - single poll for everything
1610        let counter_slice = res.readback_counters.slice(..4);
1611        let kp_slice = res.readback_keypoints.slice(..(max_keypoints * 16));
1612        let desc_slice = res.readback_descriptors.slice(..(max_keypoints * 128));
1613
1614        let (tx1, rx1) = futures::channel::oneshot::channel();
1615        let (tx2, rx2) = futures::channel::oneshot::channel();
1616        let (tx3, rx3) = futures::channel::oneshot::channel();
1617
1618        counter_slice.map_async(wgpu::MapMode::Read, move |result| {
1619            let _ = tx1.send(result);
1620        });
1621        kp_slice.map_async(wgpu::MapMode::Read, move |result| {
1622            let _ = tx2.send(result);
1623        });
1624        desc_slice.map_async(wgpu::MapMode::Read, move |result| {
1625            let _ = tx3.send(result);
1626        });
1627
1628        // Ensure works are submitted and mapping process started
1629        // On native: Wait blocks until done.
1630        // On Web: Wait is no-op or poll. We need to await the channels to let event loop run.
1631        #[cfg(not(target_arch = "wasm32"))]
1632        {
1633            // Native polling usually requires Maintain::Wait but wgpu 25+ might have changed api.
1634            // For now commenting out to avoid IDE errors if Maintain is missing.
1635            // self.device.poll(wgpu::Maintain::Wait);
1636            // self.device.poll(wgpu::Maintain::Poll);
1637        }
1638
1639        #[cfg(target_arch = "wasm32")]
1640        {
1641            // Give the browser a chance to poll if needed, though await below does it implicitly
1642            // self.device.poll(wgpu::Maintain::Poll);
1643        }
1644
1645        // Await results (non-blocking yield)
1646        rx1.await??;
1647        rx2.await??;
1648        rx3.await??;
1649
1650        // Read counter first
1651        let counter_data = counter_slice.get_mapped_range();
1652        let num_keypoints = u32::from_le_bytes([
1653            counter_data[0],
1654            counter_data[1],
1655            counter_data[2],
1656            counter_data[3],
1657        ])
1658        .min(self.config.max_keypoints) as usize;
1659        drop(counter_data);
1660        res.readback_counters.unmap();
1661
1662        if num_keypoints == 0 {
1663            res.readback_keypoints.unmap();
1664            res.readback_descriptors.unmap();
1665            return Ok((Vec::new(), Vec::new()));
1666        }
1667
1668        // Read keypoints
1669        let kp_data = kp_slice.get_mapped_range();
1670        let mut keypoints = Vec::with_capacity(num_keypoints);
1671
1672        for i in 0..num_keypoints {
1673            let offset = i * 16;
1674            let x = f32::from_le_bytes([
1675                kp_data[offset],
1676                kp_data[offset + 1],
1677                kp_data[offset + 2],
1678                kp_data[offset + 3],
1679            ]);
1680            let y = f32::from_le_bytes([
1681                kp_data[offset + 4],
1682                kp_data[offset + 5],
1683                kp_data[offset + 6],
1684                kp_data[offset + 7],
1685            ]);
1686            let sigma = f32::from_le_bytes([
1687                kp_data[offset + 8],
1688                kp_data[offset + 9],
1689                kp_data[offset + 10],
1690                kp_data[offset + 11],
1691            ]);
1692            let angle = f32::from_le_bytes([
1693                kp_data[offset + 12],
1694                kp_data[offset + 13],
1695                kp_data[offset + 14],
1696                kp_data[offset + 15],
1697            ]);
1698
1699            keypoints.push(KeyPoint {
1700                x,
1701                y,
1702                size: sigma * 2.0,
1703                angle,
1704                response: 0.0,
1705                octave: 0,
1706                layer: 0,
1707            });
1708        }
1709        drop(kp_data);
1710        res.readback_keypoints.unmap();
1711
1712        // Read descriptors
1713        let desc_data = desc_slice.get_mapped_range();
1714        let mut descriptors = Vec::with_capacity(num_keypoints);
1715
1716        for i in 0..num_keypoints {
1717            let offset = i * 128;
1718            let mut desc = [0u8; 128];
1719            desc.copy_from_slice(&desc_data[offset..offset + 128]);
1720            descriptors.push(desc);
1721        }
1722        drop(desc_data);
1723        res.readback_descriptors.unmap();
1724
1725        Ok((keypoints, descriptors))
1726    }
1727}