librashader_runtime_wgpu/filter_chain.rs

1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19use std::sync::Arc;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::FramebufferInit;
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28use wgpu::{Device, TextureFormat};
29
30use crate::error;
31use crate::error::FilterChainError;
32use crate::filter_pass::FilterPass;
33use crate::framebuffer::WgpuOutputView;
34use crate::graphics_pipeline::WgpuGraphicsPipeline;
35use crate::luts::LutTexture;
36use crate::mipmap::MipmapGen;
37use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
38use crate::samplers::SamplerSet;
39use crate::texture::{InputImage, OwnedImage};
40
41mod compile {
42    use super::*;
43    use librashader_pack::{PassResource, TextureResource};
44
45    #[cfg(not(feature = "stable"))]
46    pub type ShaderPassMeta =
47        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
48
49    #[cfg(feature = "stable")]
50    pub type ShaderPassMeta =
51        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
52
53    pub fn compile_passes(
54        shaders: Vec<PassResource>,
55        textures: &[TextureResource],
56    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
57        let (passes, semantics) = WGSL::compile_preset_passes::<
58            SpirvCompilation,
59            Naga,
60            FilterChainError,
61        >(shaders, textures.iter().map(|t| &t.meta))?;
62        Ok((passes, semantics))
63    }
64}
65
66use compile::{compile_passes, ShaderPassMeta};
67use librashader_pack::{ShaderPresetPack, TextureResource};
68use librashader_runtime::parameters::RuntimeParameters;
69
70/// A wgpu filter chain.
71pub struct FilterChainWgpu {
72    pub(crate) common: FilterCommon,
73    passes: Box<[FilterPass]>,
74    output_framebuffers: Box<[OwnedImage]>,
75    feedback_framebuffers: Box<[OwnedImage]>,
76    history_framebuffers: VecDeque<OwnedImage>,
77    disable_mipmaps: bool,
78    mipmapper: MipmapGen,
79    default_frame_options: FrameOptionsWgpu,
80    draw_last_pass_feedback: bool,
81}
82
83pub(crate) struct FilterCommon {
84    pub output_textures: Box<[Option<InputImage>]>,
85    pub feedback_textures: Box<[Option<InputImage>]>,
86    pub history_textures: Box<[Option<InputImage>]>,
87    pub luts: FastHashMap<usize, LutTexture>,
88    pub samplers: SamplerSet,
89    pub config: RuntimeParameters,
90    pub(crate) draw_quad: DrawQuad,
91    pub(crate) device: Arc<Device>,
92    pub(crate) queue: Arc<wgpu::Queue>,
93}
94
95impl FilterChainWgpu {
96    /// Load the shader preset at the given path into a filter chain.
97    pub fn load_from_path(
98        path: impl AsRef<Path>,
99        features: ShaderFeatures,
100        device: Arc<Device>,
101        queue: Arc<wgpu::Queue>,
102        options: Option<&FilterChainOptionsWgpu>,
103    ) -> error::Result<FilterChainWgpu> {
104        // load passes from preset
105        let preset = ShaderPreset::try_parse(path, features)?;
106
107        Self::load_from_preset(preset, device, queue, options)
108    }
109
110    /// Load a filter chain from a pre-parsed `ShaderPreset`.
111    pub fn load_from_preset(
112        preset: ShaderPreset,
113        device: Arc<Device>,
114        queue: Arc<wgpu::Queue>,
115        options: Option<&FilterChainOptionsWgpu>,
116    ) -> error::Result<FilterChainWgpu> {
117        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
118        Self::load_from_pack(preset, device, queue, options)
119    }
120
121    /// Load a filter chain from a pre-parsed and loaded `ShaderPresetPack`.
122    pub fn load_from_pack(
123        preset: ShaderPresetPack,
124        device: Arc<Device>,
125        queue: Arc<wgpu::Queue>,
126        options: Option<&FilterChainOptionsWgpu>,
127    ) -> error::Result<FilterChainWgpu> {
128        let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
129            label: Some("librashader load cmd"),
130        });
131        let filter_chain = Self::load_from_pack_deferred(
132            preset,
133            Arc::clone(&device),
134            Arc::clone(&queue),
135            &mut cmd,
136            options,
137        )?;
138
139        let cmd = cmd.finish();
140
141        // Wait for device
142        let index = queue.submit([cmd]);
143        device.poll(wgpu::Maintain::WaitForSubmissionIndex(index));
144
145        Ok(filter_chain)
146    }
147
148    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
149    /// to the caller. This function therefore requires no external synchronization of the device queue.
150    ///
151    /// ## Safety
152    /// The provided command buffer must be ready for recording and contain no prior commands.
153    /// The caller is responsible for ending the command buffer and immediately submitting it to a
154    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
155    pub fn load_from_preset_deferred(
156        preset: ShaderPreset,
157        device: Arc<Device>,
158        queue: Arc<wgpu::Queue>,
159        cmd: &mut wgpu::CommandEncoder,
160        options: Option<&FilterChainOptionsWgpu>,
161    ) -> error::Result<FilterChainWgpu> {
162        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
163        Self::load_from_pack_deferred(preset, device, queue, cmd, options)
164    }
165
166    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
167    /// to the caller. This function therefore requires no external synchronization of the device queue.
168    ///
169    /// ## Safety
170    /// The provided command buffer must be ready for recording and contain no prior commands.
171    /// The caller is responsible for ending the command buffer and immediately submitting it to a
172    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
173    pub fn load_from_pack_deferred(
174        preset: ShaderPresetPack,
175        device: Arc<Device>,
176        queue: Arc<wgpu::Queue>,
177        cmd: &mut wgpu::CommandEncoder,
178        options: Option<&FilterChainOptionsWgpu>,
179    ) -> error::Result<FilterChainWgpu> {
180        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
181
182        // cache is opt-in for wgpu, not opt-out because of feature requirements.
183        let disable_cache = options.map_or(true, |o| !o.enable_cache);
184
185        // initialize passes
186        let filters = Self::init_passes(
187            &device,
188            passes,
189            &semantics,
190            options.and_then(|o| o.adapter_info.as_ref()),
191            disable_cache,
192        )?;
193
194        let samplers = SamplerSet::new(&device);
195        let mut mipmapper = MipmapGen::new(&device);
196        let luts = FilterChainWgpu::load_luts(
197            &device,
198            &queue,
199            cmd,
200            &mut mipmapper,
201            &samplers,
202            preset.textures,
203        )?;
204        //
205        let framebuffer_gen = || {
206            Ok::<_, error::FilterChainError>(OwnedImage::new(
207                &device,
208                Size::new(1, 1),
209                1,
210                TextureFormat::Bgra8Unorm,
211            ))
212        };
213        let input_gen = || None;
214        let framebuffer_init = FramebufferInit::new(
215            filters.iter().map(|f| &f.reflection.meta),
216            &framebuffer_gen,
217            &input_gen,
218        );
219
220        //
221        // // initialize output framebuffers
222        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
223        //
224        // initialize feedback framebuffers
225        let (feedback_framebuffers, feedback_textures) =
226            framebuffer_init.init_output_framebuffers()?;
227        //
228        // initialize history
229        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
230
231        let draw_quad = DrawQuad::new(&device);
232
233        Ok(FilterChainWgpu {
234            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
235            common: FilterCommon {
236                luts,
237                samplers,
238                config: RuntimeParameters::new(preset.pass_count as usize, preset.parameters),
239                draw_quad,
240                device,
241                queue,
242                output_textures,
243                feedback_textures,
244                history_textures,
245            },
246            passes: filters,
247            output_framebuffers,
248            feedback_framebuffers,
249            history_framebuffers,
250            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
251            mipmapper,
252            default_frame_options: Default::default(),
253        })
254    }
255
256    fn load_luts(
257        device: &wgpu::Device,
258        queue: &wgpu::Queue,
259        cmd: &mut wgpu::CommandEncoder,
260        mipmapper: &mut MipmapGen,
261        sampler_set: &SamplerSet,
262        textures: Vec<TextureResource>,
263    ) -> error::Result<FastHashMap<usize, LutTexture>> {
264        let mut luts = FastHashMap::default();
265
266        #[cfg(not(target_arch = "wasm32"))]
267        let images_iter = textures.into_par_iter();
268
269        #[cfg(target_arch = "wasm32")]
270        let images_iter = textures.into_iter();
271
272        let textures = images_iter
273            .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
274            .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
275        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
276            let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
277            luts.insert(index, texture);
278        }
279        Ok(luts)
280    }
281
282    fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
283        if let Some(mut back) = self.history_framebuffers.pop_back() {
284            if back.image.size() != input.size() || input.format() != back.image.format() {
285                // old back will get dropped.. do we need to defer?
286                let _old_back = std::mem::replace(
287                    &mut back,
288                    OwnedImage::new(
289                        &self.common.device,
290                        input.size().into(),
291                        1,
292                        input.format().into(),
293                    ),
294                );
295            }
296
297            back.copy_from(&self.common.device, cmd, input);
298
299            self.history_framebuffers.push_front(back)
300        }
301    }
302
303    fn init_passes(
304        device: &wgpu::Device,
305        passes: Vec<ShaderPassMeta>,
306        semantics: &ShaderSemantics,
307        adapter_info: Option<&wgpu::AdapterInfo>,
308        disable_cache: bool,
309    ) -> error::Result<Box<[FilterPass]>> {
310        #[cfg(not(target_arch = "wasm32"))]
311        let filter_creation_fn = || {
312            let passes_iter = passes.into_par_iter();
313            #[cfg(target_arch = "wasm32")]
314            let passes_iter = passes.into_iter();
315
316            let filters: Vec<error::Result<FilterPass>> = passes_iter
317                .enumerate()
318                .map(|(index, (config, mut reflect))| {
319                    let reflection = reflect.reflect(index, semantics)?;
320                    let wgsl = reflect.compile(NagaLoweringOptions {
321                        write_pcb_as_ubo: true,
322                        sampler_bind_group: 1,
323                    })?;
324
325                    let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
326                    let push_size = reflection
327                        .push_constant
328                        .as_ref()
329                        .map_or(0, |push| push.size as wgpu::BufferAddress);
330
331                    let uniform_storage = UniformStorage::new_with_storage(
332                        WgpuStagedBuffer::new(
333                            &device,
334                            wgpu::BufferUsages::UNIFORM,
335                            ubo_size as wgpu::BufferAddress,
336                            Some("ubo"),
337                        ),
338                        WgpuStagedBuffer::new(
339                            &device,
340                            wgpu::BufferUsages::UNIFORM,
341                            push_size as wgpu::BufferAddress,
342                            Some("push"),
343                        ),
344                    );
345
346                    let uniform_bindings =
347                        reflection.meta.create_binding_map(|param| param.offset());
348
349                    let render_pass_format: Option<TextureFormat> =
350                        if let Some(format) = config.meta.get_format_override() {
351                            format.into()
352                        } else {
353                            config.data.format.into()
354                        };
355
356                    let graphics_pipeline = WgpuGraphicsPipeline::new(
357                        &device,
358                        &wgsl,
359                        &reflection,
360                        render_pass_format.unwrap_or(TextureFormat::Rgba8Unorm),
361                        adapter_info,
362                        disable_cache,
363                    );
364
365                    Ok(FilterPass {
366                        reflection,
367                        uniform_storage,
368                        uniform_bindings,
369                        source: config.data,
370                        meta: config.meta,
371                        graphics_pipeline,
372                    })
373                })
374                .collect();
375            filters
376        };
377
378        #[cfg(target_arch = "wasm32")]
379        let filters = filter_creation_fn();
380
381        #[cfg(not(target_arch = "wasm32"))]
382        let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
383            // naga compilations can possibly use degenerate stack sizes.
384            .stack_size(10 * 1048576)
385            .build()
386        {
387            thread_pool.install(|| filter_creation_fn())
388        } else {
389            filter_creation_fn()
390        };
391
392        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
393        let filters = filters?;
394        Ok(filters.into_boxed_slice())
395    }
396
397    /// Records shader rendering commands to the provided command encoder.
398    pub fn frame<'a>(
399        &mut self,
400        input: Arc<wgpu::Texture>,
401        viewport: &Viewport<WgpuOutputView<'a>>,
402        cmd: &mut wgpu::CommandEncoder,
403        frame_count: usize,
404        options: Option<&FrameOptionsWgpu>,
405    ) -> error::Result<()> {
406        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
407        let passes = &mut self.passes[0..max];
408
409        if let Some(options) = &options {
410            if options.clear_history {
411                for history in &mut self.history_framebuffers {
412                    history.clear(cmd);
413                }
414            }
415        }
416
417        if passes.is_empty() {
418            return Ok(());
419        }
420
421        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
422
423        let filter = passes[0].meta.filter;
424        let wrap_mode = passes[0].meta.wrap_mode;
425
426        // update history
427        for (texture, image) in self
428            .common
429            .history_textures
430            .iter_mut()
431            .zip(self.history_framebuffers.iter())
432        {
433            *texture = Some(image.as_input(filter, wrap_mode));
434        }
435
436        let original = InputImage {
437            image: Arc::clone(&input),
438            view: Arc::new(original_image_view),
439            wrap_mode,
440            filter_mode: filter,
441            mip_filter: filter,
442        };
443
444        let mut source = original.clone();
445
446        // swap output and feedback **before** recording command buffers
447        std::mem::swap(
448            &mut self.output_framebuffers,
449            &mut self.feedback_framebuffers,
450        );
451
452        // rescale render buffers to ensure all bindings are valid.
453        OwnedImage::scale_framebuffers_with_context(
454            source.image.size().into(),
455            viewport.output.size,
456            original.image.size().into(),
457            &mut self.output_framebuffers,
458            &mut self.feedback_framebuffers,
459            passes,
460            &self.common.device,
461            Some(&mut |index: usize,
462                       pass: &FilterPass,
463                       output: &OwnedImage,
464                       feedback: &OwnedImage| {
465                // refresh inputs
466                self.common.feedback_textures[index] =
467                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
468                self.common.output_textures[index] =
469                    Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
470                Ok(())
471            }),
472        )?;
473
474        let passes_len = passes.len();
475        let (pass, last) = passes.split_at_mut(passes_len - 1);
476
477        let options = options.unwrap_or(&self.default_frame_options);
478
479        for (index, pass) in pass.iter_mut().enumerate() {
480            source.filter_mode = pass.meta.filter;
481            source.wrap_mode = pass.meta.wrap_mode;
482            source.mip_filter = pass.meta.filter;
483
484            let target = &self.output_framebuffers[index];
485            let output_image = WgpuOutputView::from(target);
486            let out = RenderTarget::identity(&output_image)?;
487
488            pass.draw(
489                cmd,
490                index,
491                &self.common,
492                pass.meta.get_frame_count(frame_count),
493                options,
494                viewport,
495                &original,
496                &source,
497                &out,
498                QuadType::Offscreen,
499            )?;
500
501            if target.max_miplevels > 1 && !self.disable_mipmaps {
502                let sampler = self.common.samplers.get(
503                    WrapMode::ClampToEdge,
504                    FilterMode::Linear,
505                    FilterMode::Nearest,
506                );
507
508                target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
509            }
510
511            source = self.common.output_textures[index].clone().unwrap();
512        }
513
514        // try to hint the optimizer
515        assert_eq!(last.len(), 1);
516
517        if let Some(pass) = last.iter_mut().next() {
518            let index = passes_len - 1;
519            if !pass.graphics_pipeline.has_format(viewport.output.format) {
520                // need to recompile
521                pass.graphics_pipeline
522                    .recompile(&self.common.device, viewport.output.format);
523            }
524
525            source.filter_mode = pass.meta.filter;
526            source.wrap_mode = pass.meta.wrap_mode;
527            source.mip_filter = pass.meta.filter;
528
529            if self.draw_last_pass_feedback {
530                let target = &self.output_framebuffers[index];
531                let output_image = WgpuOutputView::from(target);
532                let out = RenderTarget::viewport_with_output(&output_image, viewport);
533
534                pass.draw(
535                    cmd,
536                    index,
537                    &self.common,
538                    pass.meta.get_frame_count(frame_count),
539                    options,
540                    viewport,
541                    &original,
542                    &source,
543                    &out,
544                    QuadType::Final,
545                )?;
546            }
547
548            let out = RenderTarget::viewport(viewport);
549            pass.draw(
550                cmd,
551                index,
552                &self.common,
553                pass.meta.get_frame_count(frame_count),
554                options,
555                viewport,
556                &original,
557                &source,
558                &out,
559                QuadType::Final,
560            )?;
561        }
562
563        self.push_history(&input, cmd);
564        Ok(())
565    }
566}