librashader_runtime_wgpu/
filter_chain.rs

1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19
20use crate::buffer::WgpuStagedBuffer;
21use crate::draw_quad::DrawQuad;
22use librashader_common::{FilterMode, Size, Viewport, WrapMode};
23use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::render_target::RenderTarget;
26use librashader_runtime::scaling::ScaleFramebuffer;
27
28use crate::error;
29use crate::error::FilterChainError;
30use crate::filter_pass::FilterPass;
31use crate::framebuffer::WgpuOutputView;
32use crate::graphics_pipeline::WgpuGraphicsPipeline;
33use crate::luts::LutTexture;
34use crate::mipmap::MipmapGen;
35use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
36use crate::samplers::SamplerSet;
37use crate::texture::{InputImage, OwnedImage};
38
39mod compile {
40    use super::*;
41    use librashader_pack::{PassResource, TextureResource};
42
43    #[cfg(not(feature = "stable"))]
44    pub type ShaderPassMeta =
45        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
46
47    #[cfg(feature = "stable")]
48    pub type ShaderPassMeta =
49        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
50
51    #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
52    pub fn compile_passes(
53        shaders: Vec<PassResource>,
54        textures: &[TextureResource],
55    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
56        let (passes, semantics) = WGSL::compile_preset_passes::<
57            SpirvCompilation,
58            Naga,
59            FilterChainError,
60        >(shaders, textures.iter().map(|t| &t.meta))?;
61        Ok((passes, semantics))
62    }
63}
64
65use compile::{compile_passes, ShaderPassMeta};
66use librashader_pack::{ShaderPresetPack, TextureResource};
67use librashader_runtime::parameters::RuntimeParameters;
68
/// A wgpu filter chain.
///
/// Built by one of the `load_from_*` constructors and driven once per frame
/// through `frame`, which records the draw commands for every enabled pass.
pub struct FilterChainWgpu {
    /// State handed to every pass while drawing (bound textures, LUTs, samplers,
    /// runtime parameters and the device/queue handles).
    pub(crate) common: FilterCommon,
    /// One entry per shader pass, as produced by `init_passes`.
    passes: Box<[FilterPass]>,
    /// Render targets the passes draw into, indexed by pass. Swapped with
    /// `feedback_framebuffers` at the start of every `frame` call.
    output_framebuffers: Box<[OwnedImage]>,
    /// The other half of the output/feedback swap; after the swap in `frame` these
    /// hold what was rendered by the previous call.
    feedback_framebuffers: Box<[OwnedImage]>,
    /// Copies of earlier input frames. `push_history` recycles the back entry,
    /// copies the current input into it and pushes it to the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// When set, `frame` skips mipmap generation for intermediate pass outputs.
    /// Taken from `FilterChainOptionsWgpu::force_no_mipmaps` (false if no options were given).
    disable_mipmaps: bool,
    /// Mipmap generator used for LUT uploads and for intermediate pass outputs.
    mipmapper: MipmapGen,
    /// Frame options used by `frame` when the caller passes `None`.
    default_frame_options: FrameOptionsWgpu,
    /// When set, the final pass is additionally drawn into its own output
    /// framebuffer before being drawn to the viewport.
    draw_last_pass_feedback: bool,
}
81
/// State shared by all passes of a [`FilterChainWgpu`]; a reference to it is passed
/// to every `FilterPass::draw` call.
pub(crate) struct FilterCommon {
    /// Per-pass input view of each pass's output framebuffer; refreshed in `frame`
    /// when the framebuffers are rescaled.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Per-pass input view of each pass's feedback framebuffer; refreshed alongside
    /// `output_textures`.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Input views of the history framebuffers; refreshed at the start of `frame`.
    pub history_textures: Box<[Option<InputImage>]>,
    /// Lookup textures, keyed by their position in the preset's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Sampler set created once from the device at load time.
    pub samplers: SamplerSet,
    /// Runtime parameters built from the preset's pass count and parameter list.
    pub config: RuntimeParameters,
    /// Quad geometry helper created from the device at load time.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device handle the chain was created with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue handle the chain was created with.
    pub(crate) queue: wgpu::Queue,
}
93
94impl FilterChainWgpu {
95    /// Load the shader preset at the given path into a filter chain.
96    pub fn load_from_path(
97        path: impl AsRef<Path>,
98        features: ShaderFeatures,
99        device: &wgpu::Device,
100        queue: &wgpu::Queue,
101        options: Option<&FilterChainOptionsWgpu>,
102    ) -> error::Result<FilterChainWgpu> {
103        // load passes from preset
104        let preset = ShaderPreset::try_parse(path, features)?;
105
106        Self::load_from_preset(preset, device, queue, options)
107    }
108
109    /// Load a filter chain from a pre-parsed `ShaderPreset`.
110    pub fn load_from_preset(
111        preset: ShaderPreset,
112        device: &wgpu::Device,
113        queue: &wgpu::Queue,
114        options: Option<&FilterChainOptionsWgpu>,
115    ) -> error::Result<FilterChainWgpu> {
116        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
117        Self::load_from_pack(preset, device, queue, options)
118    }
119
120    /// Load a filter chain from a pre-parsed and loaded `ShaderPresetPack`.
121    pub fn load_from_pack(
122        preset: ShaderPresetPack,
123        device: &wgpu::Device,
124        queue: &wgpu::Queue,
125        options: Option<&FilterChainOptionsWgpu>,
126    ) -> error::Result<FilterChainWgpu> {
127        let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
128            label: Some("librashader load cmd"),
129        });
130        let filter_chain =
131            Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
132
133        let cmd = cmd.finish();
134
135        // Wait for device
136        let index = queue.submit([cmd]);
137        device.poll(wgpu::Maintain::WaitForSubmissionIndex(index));
138
139        Ok(filter_chain)
140    }
141
142    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
143    /// to the caller. This function therefore requires no external synchronization of the device queue.
144    ///
145    /// ## Safety
146    /// The provided command buffer must be ready for recording and contain no prior commands.
147    /// The caller is responsible for ending the command buffer and immediately submitting it to a
148    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
149    pub fn load_from_preset_deferred(
150        preset: ShaderPreset,
151        device: &wgpu::Device,
152        queue: &wgpu::Queue,
153        cmd: &mut wgpu::CommandEncoder,
154        options: Option<&FilterChainOptionsWgpu>,
155    ) -> error::Result<FilterChainWgpu> {
156        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
157        Self::load_from_pack_deferred(preset, device, queue, cmd, options)
158    }
159
160    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
161    /// to the caller. This function therefore requires no external synchronization of the device queue.
162    ///
163    /// ## Safety
164    /// The provided command buffer must be ready for recording and contain no prior commands.
165    /// The caller is responsible for ending the command buffer and immediately submitting it to a
166    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
167    pub fn load_from_pack_deferred(
168        preset: ShaderPresetPack,
169        device: &wgpu::Device,
170        queue: &wgpu::Queue,
171        cmd: &mut wgpu::CommandEncoder,
172        options: Option<&FilterChainOptionsWgpu>,
173    ) -> error::Result<FilterChainWgpu> {
174        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
175
176        // cache is opt-in for wgpu, not opt-out because of feature requirements.
177        let disable_cache = options.map_or(true, |o| !o.enable_cache);
178
179        // initialize passes
180        let filters = Self::init_passes(
181            &device,
182            passes,
183            &semantics,
184            options.and_then(|o| o.adapter_info.as_ref()),
185            disable_cache,
186        )?;
187
188        let samplers = SamplerSet::new(&device);
189        let mut mipmapper = MipmapGen::new(&device);
190        let luts = FilterChainWgpu::load_luts(
191            &device,
192            &queue,
193            cmd,
194            &mut mipmapper,
195            &samplers,
196            preset.textures,
197        )?;
198        //
199        let framebuffer_gen = || {
200            Ok::<_, error::FilterChainError>(OwnedImage::new(
201                &device,
202                Size::new(1, 1),
203                1,
204                wgpu::TextureFormat::Bgra8Unorm,
205            ))
206        };
207        let input_gen = || None;
208        let framebuffer_init = FramebufferInit::new(
209            filters.iter().map(|f| &f.reflection.meta),
210            &framebuffer_gen,
211            &input_gen,
212        );
213
214        //
215        // // initialize output framebuffers
216        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
217        //
218        // initialize feedback framebuffers
219        let (feedback_framebuffers, feedback_textures) =
220            framebuffer_init.init_output_framebuffers()?;
221        //
222        // initialize history
223        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
224
225        let draw_quad = DrawQuad::new(&device);
226
227        Ok(FilterChainWgpu {
228            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
229            common: FilterCommon {
230                luts,
231                samplers,
232                config: RuntimeParameters::new(preset.pass_count as usize, preset.parameters),
233                draw_quad,
234                device: device.clone(),
235                queue: queue.clone(),
236                output_textures,
237                feedback_textures,
238                history_textures,
239            },
240            passes: filters,
241            output_framebuffers,
242            feedback_framebuffers,
243            history_framebuffers,
244            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
245            mipmapper,
246            default_frame_options: Default::default(),
247        })
248    }
249
250    fn load_luts(
251        device: &wgpu::Device,
252        queue: &wgpu::Queue,
253        cmd: &mut wgpu::CommandEncoder,
254        mipmapper: &mut MipmapGen,
255        sampler_set: &SamplerSet,
256        textures: Vec<TextureResource>,
257    ) -> error::Result<FastHashMap<usize, LutTexture>> {
258        let mut luts = FastHashMap::default();
259
260        #[cfg(not(target_arch = "wasm32"))]
261        let images_iter = textures.into_par_iter();
262
263        #[cfg(target_arch = "wasm32")]
264        let images_iter = textures.into_iter();
265
266        let textures = images_iter
267            .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
268            .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
269        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
270            let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
271            luts.insert(index, texture);
272        }
273        Ok(luts)
274    }
275
276    fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
277        if let Some(mut back) = self.history_framebuffers.pop_back() {
278            if back.image.size() != input.size() || input.format() != back.image.format() {
279                // old back will get dropped.. do we need to defer?
280                let _old_back = std::mem::replace(
281                    &mut back,
282                    OwnedImage::new(
283                        &self.common.device,
284                        input.size().into(),
285                        1,
286                        input.format().into(),
287                    ),
288                );
289            }
290
291            back.copy_from(&self.common.device, cmd, input);
292
293            self.history_framebuffers.push_front(back)
294        }
295    }
296
297    fn init_passes(
298        device: &wgpu::Device,
299        passes: Vec<ShaderPassMeta>,
300        semantics: &ShaderSemantics,
301        adapter_info: Option<&wgpu::AdapterInfo>,
302        disable_cache: bool,
303    ) -> error::Result<Box<[FilterPass]>> {
304        #[cfg(not(target_arch = "wasm32"))]
305        let filter_creation_fn = || {
306            let passes_iter = passes.into_par_iter();
307            #[cfg(target_arch = "wasm32")]
308            let passes_iter = passes.into_iter();
309
310            let filters: Vec<error::Result<FilterPass>> = passes_iter
311                .enumerate()
312                .map(|(index, (config, mut reflect))| {
313                    let reflection = reflect.reflect(index, semantics)?;
314                    let wgsl = reflect.compile(NagaLoweringOptions {
315                        write_pcb_as_ubo: true,
316                        sampler_bind_group: 1,
317                    })?;
318
319                    let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
320                    let push_size = reflection
321                        .push_constant
322                        .as_ref()
323                        .map_or(0, |push| push.size as wgpu::BufferAddress);
324
325                    let uniform_storage = UniformStorage::new_with_storage(
326                        WgpuStagedBuffer::new(
327                            &device,
328                            wgpu::BufferUsages::UNIFORM,
329                            ubo_size as wgpu::BufferAddress,
330                            Some("ubo"),
331                        ),
332                        WgpuStagedBuffer::new(
333                            &device,
334                            wgpu::BufferUsages::UNIFORM,
335                            push_size as wgpu::BufferAddress,
336                            Some("push"),
337                        ),
338                    );
339
340                    let uniform_bindings =
341                        reflection.meta.create_binding_map(|param| param.offset());
342
343                    let render_pass_format: Option<wgpu::TextureFormat> =
344                        if let Some(format) = config.meta.get_format_override() {
345                            format.into()
346                        } else {
347                            config.data.format.into()
348                        };
349
350                    let graphics_pipeline = WgpuGraphicsPipeline::new(
351                        &device,
352                        &wgsl,
353                        &reflection,
354                        render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
355                        adapter_info,
356                        disable_cache,
357                    );
358
359                    Ok(FilterPass {
360                        reflection,
361                        uniform_storage,
362                        uniform_bindings,
363                        source: config.data,
364                        meta: config.meta,
365                        graphics_pipeline,
366                    })
367                })
368                .collect();
369            filters
370        };
371
372        #[cfg(target_arch = "wasm32")]
373        let filters = filter_creation_fn();
374
375        #[cfg(not(target_arch = "wasm32"))]
376        let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
377            // naga compilations can possibly use degenerate stack sizes.
378            .stack_size(10 * 1048576)
379            .build()
380        {
381            thread_pool.install(|| filter_creation_fn())
382        } else {
383            filter_creation_fn()
384        };
385
386        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
387        let filters = filters?;
388        Ok(filters.into_boxed_slice())
389    }
390
    /// Records shader rendering commands to the provided command encoder.
    ///
    /// Nothing is submitted here; the caller owns `cmd` and is responsible for
    /// finishing and submitting it.
    ///
    /// * `input` - the source texture for this frame.
    /// * `viewport` - the final output view; the last pass is drawn into it.
    /// * `cmd` - encoder that all draw, mipmap and history-copy commands are recorded into.
    /// * `frame_count` - frame counter, passed through each pass's `get_frame_count`.
    /// * `options` - per-frame options; the chain's default options are used when `None`.
    ///
    /// Only the first `min(passes.len(), passes_enabled())` passes are drawn. If that
    /// number is zero the function returns `Ok(())` right after the optional history
    /// clear, without drawing anything and without pushing `input` into the history.
    ///
    /// # Errors
    /// Returns an error if rescaling the framebuffers, building a render target, or
    /// drawing a pass fails.
    pub fn frame<'a>(
        &mut self,
        input: &wgpu::Texture,
        viewport: &Viewport<WgpuOutputView<'a>>,
        cmd: &mut wgpu::CommandEncoder,
        frame_count: usize,
        options: Option<&FrameOptionsWgpu>,
    ) -> error::Result<()> {
        // Clamp to the number of passes that are currently enabled at runtime.
        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
        let passes = &mut self.passes[0..max];

        if let Some(options) = &options {
            if options.clear_history {
                for history in &mut self.history_framebuffers {
                    history.clear(cmd);
                }
            }
        }

        if passes.is_empty() {
            return Ok(());
        }

        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());

        // The original input and the history are sampled with the first pass's settings.
        let filter = passes[0].meta.filter;
        let wrap_mode = passes[0].meta.wrap_mode;

        // update history
        for (texture, image) in self
            .common
            .history_textures
            .iter_mut()
            .zip(self.history_framebuffers.iter())
        {
            *texture = Some(image.as_input(filter, wrap_mode));
        }

        let original = InputImage {
            image: input.clone(),
            view: original_image_view,
            wrap_mode,
            filter_mode: filter,
            mip_filter: filter,
        };

        // `source` is the input of the pass about to be drawn; it starts as the
        // original image and is replaced by each pass's output in turn.
        let mut source = original.clone();

        // swap output and feedback **before** recording command buffers
        std::mem::swap(
            &mut self.output_framebuffers,
            &mut self.feedback_framebuffers,
        );

        // rescale render buffers to ensure all bindings are valid.
        OwnedImage::scale_framebuffers_with_context(
            source.image.size().into(),
            viewport.output.size,
            original.image.size().into(),
            &mut self.output_framebuffers,
            &mut self.feedback_framebuffers,
            passes,
            &self.common.device,
            Some(&mut |index: usize,
                       pass: &FilterPass,
                       output: &OwnedImage,
                       feedback: &OwnedImage| {
                // refresh inputs
                self.common.feedback_textures[index] =
                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
                self.common.output_textures[index] =
                    Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
                Ok(())
            }),
        )?;

        // Split off the final pass: it renders to the viewport rather than to an
        // intermediate framebuffer. `passes` is non-empty here, so the subtraction
        // cannot underflow.
        let passes_len = passes.len();
        let (pass, last) = passes.split_at_mut(passes_len - 1);

        let options = options.unwrap_or(&self.default_frame_options);

        // Intermediate passes: draw into the pass's own output framebuffer.
        for (index, pass) in pass.iter_mut().enumerate() {
            source.filter_mode = pass.meta.filter;
            source.wrap_mode = pass.meta.wrap_mode;
            source.mip_filter = pass.meta.filter;

            let target = &self.output_framebuffers[index];
            let output_image = WgpuOutputView::from(target);
            let out = RenderTarget::identity(&output_image)?;

            pass.draw(
                cmd,
                index,
                &self.common,
                pass.meta.get_frame_count(frame_count),
                options,
                viewport,
                &original,
                &source,
                &out,
                QuadType::Offscreen,
            )?;

            // Only targets allocated with more than one mip level get mipmaps,
            // and only if mipmaps were not disabled at load time.
            if target.max_miplevels > 1 && !self.disable_mipmaps {
                let sampler = self.common.samplers.get(
                    WrapMode::ClampToEdge,
                    FilterMode::Linear,
                    FilterMode::Nearest,
                );

                target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
            }

            // NOTE(review): this unwrap relies on the rescale callback above having
            // filled `output_textures[index]` for every drawn pass - confirm.
            source = self.common.output_textures[index].clone().unwrap();
        }

        // try to hint the optimizer
        assert_eq!(last.len(), 1);

        if let Some(pass) = last.iter_mut().next() {
            let index = passes_len - 1;
            // The final pass targets the viewport, whose format is only known now.
            if !pass.graphics_pipeline.has_format(viewport.output.format) {
                // need to recompile
                pass.graphics_pipeline
                    .recompile(&self.common.device, viewport.output.format);
            }

            source.filter_mode = pass.meta.filter;
            source.wrap_mode = pass.meta.wrap_mode;
            source.mip_filter = pass.meta.filter;

            // When the final pass is used as feedback, it is also drawn into its
            // own output framebuffer in addition to the viewport draw below.
            if self.draw_last_pass_feedback {
                let target = &self.output_framebuffers[index];
                let output_image = WgpuOutputView::from(target);
                let out = RenderTarget::viewport_with_output(&output_image, viewport);

                pass.draw(
                    cmd,
                    index,
                    &self.common,
                    pass.meta.get_frame_count(frame_count),
                    options,
                    viewport,
                    &original,
                    &source,
                    &out,
                    QuadType::Final,
                )?;
            }

            let out = RenderTarget::viewport(viewport);
            pass.draw(
                cmd,
                index,
                &self.common,
                pass.meta.get_frame_count(frame_count),
                options,
                viewport,
                &original,
                &source,
                &out,
                QuadType::Final,
            )?;
        }

        // Record the copy of this frame's input into the history queue.
        self.push_history(&input, cmd);
        Ok(())
    }
560}