// wgpu_mipmap/backends/compute.rs

1use crate::core::*;
2use crate::util::get_mip_extent;
3use log::warn;
4use std::collections::HashMap;
5
6/// Generates mipmaps for textures with storage usage.
7#[derive(Debug)]
8pub struct ComputeMipmapGenerator {
9    layout_cache: HashMap<wgpu::TextureFormat, wgpu::BindGroupLayout>,
10    pipeline_cache: HashMap<wgpu::TextureFormat, wgpu::ComputePipeline>,
11}
12
13impl ComputeMipmapGenerator {
14    /// Returns the texture usage `ComputeMipmapGenerator` requires for mipmap generation.
15    pub fn required_usage() -> wgpu::TextureUsage {
16        wgpu::TextureUsage::STORAGE
17    }
18
19    /// Creates a new `ComputeMipmapGenerator`. Once created, it can be used repeatedly to
20    /// generate mipmaps for any texture with format specified in `format_hints`.
21    pub fn new_with_format_hints(
22        device: &wgpu::Device,
23        format_hints: &[wgpu::TextureFormat],
24    ) -> Self {
25        let mut layout_cache = HashMap::new();
26        let mut pipeline_cache = HashMap::new();
27        for format in format_hints {
28            if let Some(module) = shader_for_format(device, format) {
29                let bind_group_layout = bind_group_layout_for_format(device, format);
30                let pipeline =
31                    compute_pipeline_for_format(device, &module, &bind_group_layout, format);
32                layout_cache.insert(*format, bind_group_layout);
33                pipeline_cache.insert(*format, pipeline);
34            } else {
35                warn!(
36                    "ComputeMipmapGenerator does not support requested format {:?}",
37                    format
38                );
39                continue;
40            }
41        }
42        Self {
43            layout_cache,
44            pipeline_cache,
45        }
46    }
47}
48
49impl MipmapGenerator for ComputeMipmapGenerator {
50    fn generate(
51        &self,
52        device: &wgpu::Device,
53        encoder: &mut wgpu::CommandEncoder,
54        texture: &wgpu::Texture,
55        texture_descriptor: &wgpu::TextureDescriptor,
56    ) -> Result<(), Error> {
57        // Texture width and height must be a power of 2
58        if !texture_descriptor.size.width.is_power_of_two()
59            || !texture_descriptor.size.height.is_power_of_two()
60        {
61            return Err(Error::NpotTexture);
62        }
63        // Texture dimension must be 2D
64        if texture_descriptor.dimension != wgpu::TextureDimension::D2 {
65            return Err(Error::UnsupportedDimension(texture_descriptor.dimension));
66        }
67        if !texture_descriptor.usage.contains(Self::required_usage()) {
68            return Err(Error::UnsupportedUsage(texture_descriptor.usage));
69        }
70
71        let layout = self
72            .layout_cache
73            .get(&texture_descriptor.format)
74            .ok_or(Error::UnknownFormat(texture_descriptor.format))?;
75        let pipeline = self
76            .pipeline_cache
77            .get(&texture_descriptor.format)
78            .ok_or(Error::UnknownFormat(texture_descriptor.format))?;
79
80        let mip_count = texture_descriptor.mip_level_count;
81        // TODO: Can we create the views every call?
82        let views = (0..mip_count)
83            .map(|base_mip_level| {
84                texture.create_view(&wgpu::TextureViewDescriptor {
85                    label: None,
86                    format: None,
87                    dimension: None,
88                    aspect: wgpu::TextureAspect::All,
89                    base_mip_level,
90                    level_count: std::num::NonZeroU32::new(1),
91                    array_layer_count: None,
92                    base_array_layer: 0,
93                })
94            })
95            .collect::<Vec<_>>();
96        // Now dispatch the compute pipeline for each mip level
97        // TODO: Likely need more flexibility here
98        // - The compute shaders must have matching local_size_x and local_size_y values
99        // - When the image size is less than 32x32, more work is performed than required
100        let x_work_group_count = 32;
101        let y_work_group_count = 32;
102        for mip in 1..mip_count as usize {
103            let src_view = &views[mip - 1];
104            let dst_view = &views[mip];
105            let mip_ext = get_mip_extent(&texture_descriptor.size, mip as u32);
106            let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
107                label: None,
108                layout,
109                entries: &[
110                    wgpu::BindGroupEntry {
111                        binding: 0,
112                        resource: wgpu::BindingResource::TextureView(&src_view),
113                    },
114                    wgpu::BindGroupEntry {
115                        binding: 1,
116                        resource: wgpu::BindingResource::TextureView(&dst_view),
117                    },
118                ],
119            });
120            let mut pass = encoder.begin_compute_pass();
121            pass.set_pipeline(pipeline);
122            pass.set_bind_group(0, &bind_group, &[]);
123            pass.dispatch(
124                (mip_ext.width / x_work_group_count).max(1),
125                (mip_ext.height / y_work_group_count).max(1),
126                1,
127            );
128        }
129        Ok(())
130    }
131}
132
133fn shader_for_format(
134    device: &wgpu::Device,
135    format: &wgpu::TextureFormat,
136) -> Option<wgpu::ShaderModule> {
137    use wgpu::TextureFormat;
138    let s = |d| Some(device.create_shader_module(wgpu::util::make_spirv(d)));
139    match format {
140        TextureFormat::R8Unorm => s(include_bytes!("shaders/box_r8.comp.spv")),
141        TextureFormat::R8Snorm => s(include_bytes!("shaders/box_r8_snorm.comp.spv")),
142        TextureFormat::R16Float => s(include_bytes!("shaders/box_r16f.comp.spv")),
143        TextureFormat::Rg8Unorm => s(include_bytes!("shaders/box_rg8.comp.spv")),
144        TextureFormat::Rg8Snorm => s(include_bytes!("shaders/box_rg8_snorm.comp.spv")),
145        TextureFormat::R32Float => s(include_bytes!("shaders/box_r32f.comp.spv")),
146        TextureFormat::Rg16Float => s(include_bytes!("shaders/box_rg16f.comp.spv")),
147        TextureFormat::Rgba8Unorm => s(include_bytes!("shaders/box_rgba8.comp.spv")),
148        TextureFormat::Rgba8UnormSrgb | TextureFormat::Bgra8UnormSrgb => {
149            // On MacOS, my GPUFamily2 v1 capable GPU
150            // seems to perform the srgb -> linear before I load it
151            // in the shader, but expects me to perform the linear -> srgb
152            // conversion before storing.
153            #[cfg(target_os = "macos")]
154            {
155                s(include_bytes!("shaders/box_srgb_macos.comp.spv"))
156            }
157            // On  Vulkan (and DX12?), the implementation does not perform
158            // any conversion, so this shader handles it all
159            #[cfg(not(target_os = "macos"))]
160            {
161                s(include_bytes!("shaders/box_srgb.comp.spv"))
162            }
163        }
164        TextureFormat::Rgba8Snorm => s(include_bytes!("shaders/box_rgba8_snorm.comp.spv")),
165        TextureFormat::Bgra8Unorm => s(include_bytes!("shaders/box_rgba8.comp.spv")),
166        TextureFormat::Rgb10a2Unorm => s(include_bytes!("shaders/box_rgb10_a2.comp.spv")),
167        TextureFormat::Rg11b10Float => s(include_bytes!("shaders/box_r11f_g11f_b10f.comp.spv")),
168        TextureFormat::Rg32Float => s(include_bytes!("shaders/box_rg32f.comp.spv")),
169        TextureFormat::Rgba16Float => s(include_bytes!("shaders/box_rgba16f.comp.spv")),
170        TextureFormat::Rgba32Float => s(include_bytes!("shaders/box_rgba32f.comp.spv")),
171        _ => None,
172    }
173}
174
175fn bind_group_layout_for_format(
176    device: &wgpu::Device,
177    format: &wgpu::TextureFormat,
178) -> wgpu::BindGroupLayout {
179    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
180        label: None,
181        entries: &[
182            wgpu::BindGroupLayoutEntry {
183                binding: 0,
184                visibility: wgpu::ShaderStage::COMPUTE,
185                ty: wgpu::BindingType::StorageTexture {
186                    dimension: wgpu::TextureViewDimension::D2,
187                    format: *format,
188                    readonly: true,
189                },
190                count: None,
191            },
192            wgpu::BindGroupLayoutEntry {
193                binding: 1,
194                visibility: wgpu::ShaderStage::COMPUTE,
195                ty: wgpu::BindingType::StorageTexture {
196                    dimension: wgpu::TextureViewDimension::D2,
197                    format: *format,
198                    readonly: false,
199                },
200                count: None,
201            },
202        ],
203    })
204}
205
206fn compute_pipeline_for_format(
207    device: &wgpu::Device,
208    module: &wgpu::ShaderModule,
209    bind_group_layout: &wgpu::BindGroupLayout,
210    format: &wgpu::TextureFormat,
211) -> wgpu::ComputePipeline {
212    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
213        label: None,
214        bind_group_layouts: &[&bind_group_layout],
215        push_constant_ranges: &[],
216    });
217    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
218        label: Some(&format!("wgpu-mipmap-compute-pipeline-{:?}", format)),
219        layout: Some(&pipeline_layout),
220        compute_stage: wgpu::ProgrammableStageDescriptor {
221            module: &module,
222            entry_point: "main",
223        },
224    })
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::*;

    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Length of a full mip chain for a square texture of side `size` (`size > 0`).
    ///
    /// This is `1 + floor(log2(size))`, computed exactly with integer arithmetic instead of
    /// `f32::log2`.
    fn full_mip_count(size: u32) -> u32 {
        32 - size.leading_zeros()
    }

    /// Builds a descriptor for a square, single-sample 2D texture with a full mip chain.
    fn square_texture_descriptor(
        size: u32,
        format: wgpu::TextureFormat,
        usage: wgpu::TextureUsage,
    ) -> wgpu::TextureDescriptor<'static> {
        wgpu::TextureDescriptor {
            size: wgpu::Extent3d {
                width: size,
                height: size,
                depth: 1,
            },
            mip_level_count: full_mip_count(size),
            format,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            usage,
            label: None,
        }
    }

    // Not called by any test in this module yet.
    #[allow(dead_code)]
    async fn generate_and_copy_to_cpu_compute(
        buffer: &[u8],
        texture_descriptor: &wgpu::TextureDescriptor<'_>,
    ) -> Result<Vec<MipBuffer>, Error> {
        let (_instance, _adapter, device, queue) = wgpu_setup().await;
        let generator =
            ComputeMipmapGenerator::new_with_format_hints(&device, &[texture_descriptor.format]);
        Ok(
            generate_and_copy_to_cpu(&device, &queue, &generator, buffer, texture_descriptor)
                .await?,
        )
    }

    /// Runs `generate` on a fresh texture created from `texture_descriptor` and returns its
    /// result. The recorded commands are never submitted.
    async fn generate_test(texture_descriptor: &wgpu::TextureDescriptor<'_>) -> Result<(), Error> {
        let (_instance, _adapter, device, _queue) = wgpu_setup().await;
        let generator =
            ComputeMipmapGenerator::new_with_format_hints(&device, &[texture_descriptor.format]);
        let texture = device.create_texture(texture_descriptor);
        let mut encoder = device.create_command_encoder(&Default::default());
        generator.generate(&device, &mut encoder, &texture, texture_descriptor)
    }

    #[test]
    fn sanity_check() {
        init();
        let texture_descriptor = square_texture_descriptor(
            512,
            wgpu::TextureFormat::R8Unorm,
            ComputeMipmapGenerator::required_usage(),
        );
        let res = futures::executor::block_on(generate_test(&texture_descriptor));
        assert!(res.is_ok());
    }

    #[test]
    fn unsupported_npot() {
        init();
        let texture_descriptor = square_texture_descriptor(
            511,
            wgpu::TextureFormat::R8Unorm,
            ComputeMipmapGenerator::required_usage(),
        );
        let res = futures::executor::block_on(generate_test(&texture_descriptor));
        assert!(res.err() == Some(Error::NpotTexture));
    }

    #[test]
    fn unsupported_usage() {
        init();
        let texture_descriptor = square_texture_descriptor(
            512,
            wgpu::TextureFormat::R8Unorm,
            wgpu::TextureUsage::empty(),
        );
        let res = futures::executor::block_on(generate_test(&texture_descriptor));
        assert!(res.err() == Some(Error::UnsupportedUsage(wgpu::TextureUsage::empty())));
    }

    #[test]
    fn unknown_format() {
        init();
        let texture_descriptor = square_texture_descriptor(
            512,
            wgpu::TextureFormat::Rg16Sint,
            ComputeMipmapGenerator::required_usage(),
        );
        let res = futures::executor::block_on(generate_test(&texture_descriptor));
        assert!(res.err() == Some(Error::UnknownFormat(wgpu::TextureFormat::Rg16Sint)));
    }
}