Skip to main content

librashader_runtime_wgpu/
filter_chain.rs

1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::ThreadPoolBuilder;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::{FramebufferInit, FramebufferPool};
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28
29use crate::error;
30use crate::error::FilterChainError;
31use crate::filter_pass::FilterPass;
32use crate::framebuffer::WgpuOutputView;
33use crate::graphics_pipeline::WgpuGraphicsPipeline;
34use crate::luts::LutTexture;
35use crate::mipmap::MipmapGen;
36use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
37use crate::samplers::SamplerSet;
38use crate::texture::{InputImage, OwnedImage};
39
#[cfg(feature = "native")]
mod compile {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};

    /// A shader pass paired with its not-yet-compiled SPIR-V -> WGSL reflection object.
    ///
    /// With the `nightly` feature the reflection object is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;

    /// A shader pass paired with its not-yet-compiled SPIR-V -> WGSL reflection object.
    ///
    /// Without the `nightly` feature the reflection object is boxed as a trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;

    /// Run `WGSL::compile_preset_passes` over `shaders` with the `SpirvCompilation` front end
    /// and the `Naga` reflector, handing it the metadata of every entry in `textures`.
    ///
    /// Returns the per-pass artifacts together with the shader semantics for the preset.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        let compiled = WGSL::compile_preset_passes::<SpirvCompilation, Naga, FilterChainError>(
            shaders,
            texture_metas,
        )?;
        Ok(compiled)
    }
}
66
#[cfg(feature = "wgsl_preset_pack")]
mod compile_wgsl {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};
    use librashader_reflect::front::WgslCompilation;

    /// A shader pass paired with its not-yet-compiled WGSL reflection object.
    ///
    /// With the `nightly` feature the reflection object is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>;

    /// A shader pass paired with its not-yet-compiled WGSL reflection object.
    ///
    /// Without the `nightly` feature the reflection object is boxed as a trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>>;

    /// Run `WGSL::compile_preset_passes` over `shaders` with the `WgslCompilation` front end
    /// and the `Naga` reflector, handing it the metadata of every entry in `textures`.
    ///
    /// Returns the per-pass artifacts together with the shader semantics for the preset.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        let compiled = WGSL::compile_preset_passes::<WgslCompilation, Naga, FilterChainError>(
            shaders,
            texture_metas,
        )?;
        Ok(compiled)
    }
}
94
95#[cfg(all(feature = "native", not(feature = "wgsl_preset_pack")))]
96use compile::{compile_passes, ShaderPassMeta};
97
98#[cfg(any(not(feature = "native"), feature = "wgsl_preset_pack"))]
99use compile_wgsl::{compile_passes, ShaderPassMeta};
100
101use librashader_pack::{ShaderPresetPack, TextureResource};
102use librashader_runtime::parameters::RuntimeParameters;
103
/// A wgpu filter chain.
///
/// Built by the `load_from_*` constructors and driven once per frame through
/// [`frame`](Self::frame), which records the draw commands for every enabled pass.
pub struct FilterChainWgpu {
    /// State shared by all passes; handed to each pass's `draw` call as `&self.common`.
    pub(crate) common: FilterCommon,
    /// The initialized filter passes, in the order `init_passes` produced them.
    passes: Box<[FilterPass]>,
    /// Per-pass render targets; swapped with the matching feedback framebuffer at the
    /// start of every frame when that pass has one.
    output_framebuffers: FramebufferPool<OwnedImage>,
    /// Per-pass feedback images, rescaled and re-bound as inputs each frame.
    feedback_framebuffers: FramebufferPool<OwnedImage>,
    /// Copies of earlier input frames. `push_history` recycles the back entry, copies the
    /// current input into it, and pushes it to the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// Taken from `force_no_mipmaps` in the chain options (false when no options are given);
    /// when set, `frame` skips mipmap generation for pass outputs.
    disable_mipmaps: bool,
    /// Mipmap generator used for LUT uploads and for pass outputs with more than one mip level.
    mipmapper: MipmapGen,
    /// Frame options used when `frame` is called with `options == None`.
    default_frame_options: FrameOptionsWgpu,
    /// Set from `FramebufferInit::uses_final_pass_as_feedback`; when true, `frame` draws the
    /// last pass into its own framebuffer before drawing it to the viewport.
    draw_last_pass_feedback: bool,
}
116
/// State shared between all passes of a [`FilterChainWgpu`]; each pass receives it by
/// reference when drawing.
pub(crate) struct FilterCommon {
    /// Input bindings for pass outputs; `frame` fills the slot for each non-final pass
    /// after that pass has been drawn.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Input bindings for the feedback framebuffers; refreshed by `frame` while the
    /// feedback framebuffers are rescaled.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Input bindings for the history framebuffers; refreshed at the start of `frame`.
    pub history_textures: Box<[Option<InputImage>]>,
    /// Look-up textures, keyed by the texture's position in the preset's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Sampler set created from the device at load time.
    pub samplers: SamplerSet,
    /// Runtime parameters created from the preset; `frame` reads `passes_enabled()` from it.
    pub config: RuntimeParameters,
    /// Quad geometry created from the device at load time.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device handle the chain was loaded with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue handle the chain was loaded with.
    pub(crate) queue: wgpu::Queue,
}
128
129impl FilterChainWgpu {
130    /// Load the shader preset at the given path into a filter chain.
131    #[cfg(feature = "native")]
132    pub fn load_from_path(
133        path: impl AsRef<Path>,
134        features: ShaderFeatures,
135        device: &wgpu::Device,
136        queue: &wgpu::Queue,
137        options: Option<&FilterChainOptionsWgpu>,
138    ) -> error::Result<FilterChainWgpu> {
139        // load passes from preset
140        let preset = ShaderPreset::try_parse(path, features)?;
141
142        Self::load_from_preset(preset, device, queue, options)
143    }
144
145    /// Load a filter chain from a pre-parsed `ShaderPreset`.
146    #[cfg(feature = "native")]
147    pub fn load_from_preset(
148        preset: ShaderPreset,
149        device: &wgpu::Device,
150        queue: &wgpu::Queue,
151        options: Option<&FilterChainOptionsWgpu>,
152    ) -> error::Result<FilterChainWgpu> {
153        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
154        Self::load_from_pack(preset, device, queue, options)
155    }
156
157    /// Load a filter chain from a pre-parsed and loaded `ShaderPresetPack`.
158    pub fn load_from_pack(
159        preset: ShaderPresetPack,
160        device: &wgpu::Device,
161        queue: &wgpu::Queue,
162        options: Option<&FilterChainOptionsWgpu>,
163    ) -> error::Result<FilterChainWgpu> {
164        let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
165            label: Some("librashader load cmd"),
166        });
167        let filter_chain =
168            Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
169
170        let cmd = cmd.finish();
171
172        // Wait for device
173        let index = queue.submit([cmd]);
174        device.poll(wgpu::PollType::Wait {
175            submission_index: Some(index),
176            timeout: None,
177        })?;
178
179        Ok(filter_chain)
180    }
181
182    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
183    /// to the caller. This function therefore requires no external synchronization of the device queue.
184    ///
185    /// ## Safety
186    /// The provided command buffer must be ready for recording and contain no prior commands.
187    /// The caller is responsible for ending the command buffer and immediately submitting it to a
188    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
189    #[cfg(feature = "native")]
190    pub fn load_from_preset_deferred(
191        preset: ShaderPreset,
192        device: &wgpu::Device,
193        queue: &wgpu::Queue,
194        cmd: &mut wgpu::CommandEncoder,
195        options: Option<&FilterChainOptionsWgpu>,
196    ) -> error::Result<FilterChainWgpu> {
197        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
198        Self::load_from_pack_deferred(preset, device, queue, cmd, options)
199    }
200
201    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
202    /// to the caller. This function therefore requires no external synchronization of the device queue.
203    ///
204    /// ## Safety
205    /// The provided command buffer must be ready for recording and contain no prior commands.
206    /// The caller is responsible for ending the command buffer and immediately submitting it to a
207    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
208    pub fn load_from_pack_deferred(
209        preset: ShaderPresetPack,
210        device: &wgpu::Device,
211        queue: &wgpu::Queue,
212        cmd: &mut wgpu::CommandEncoder,
213        options: Option<&FilterChainOptionsWgpu>,
214    ) -> error::Result<FilterChainWgpu> {
215        let config = RuntimeParameters::new(&preset);
216
217        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
218
219        // cache is opt-in for wgpu, not opt-out because of feature requirements.
220        let disable_cache = options.map_or(true, |o| !o.enable_cache);
221
222        // initialize passes
223        let filters = Self::init_passes(
224            &device,
225            passes,
226            &semantics,
227            options.and_then(|o| o.adapter_info.as_ref()),
228            disable_cache,
229        )?;
230
231        let samplers = SamplerSet::new(&device);
232        let mut mipmapper = MipmapGen::new(&device);
233        let luts = FilterChainWgpu::load_luts(
234            &device,
235            &queue,
236            cmd,
237            &mut mipmapper,
238            &samplers,
239            preset.textures,
240        )?;
241        //
242        let framebuffer_gen = || {
243            Ok::<_, error::FilterChainError>(OwnedImage::new(
244                &device,
245                Size::new(1, 1),
246                1,
247                wgpu::TextureFormat::Bgra8Unorm,
248            ))
249        };
250        let input_gen = || None;
251        let framebuffer_init = FramebufferInit::new(
252            filters.iter().map(|f| &f.reflection.meta),
253            &framebuffer_gen,
254            &input_gen,
255        );
256
257        //
258        // // initialize output framebuffers
259        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
260        //
261        // initialize feedback framebuffers
262        let (feedback_framebuffers, feedback_textures) =
263            framebuffer_init.init_feedback_framebuffers()?;
264        //
265        // initialize history
266        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
267
268        let draw_quad = DrawQuad::new(&device);
269
270        Ok(FilterChainWgpu {
271            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
272            common: FilterCommon {
273                luts,
274                samplers,
275                config,
276                draw_quad,
277                device: device.clone(),
278                queue: queue.clone(),
279                output_textures,
280                feedback_textures,
281                history_textures,
282            },
283            passes: filters,
284            output_framebuffers,
285            feedback_framebuffers,
286            history_framebuffers,
287            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
288            mipmapper,
289            default_frame_options: Default::default(),
290        })
291    }
292
293    fn load_luts(
294        device: &wgpu::Device,
295        queue: &wgpu::Queue,
296        cmd: &mut wgpu::CommandEncoder,
297        mipmapper: &mut MipmapGen,
298        sampler_set: &SamplerSet,
299        textures: Vec<TextureResource>,
300    ) -> error::Result<FastHashMap<usize, LutTexture>> {
301        let mut luts = FastHashMap::default();
302
303        #[cfg(not(target_arch = "wasm32"))]
304        let images_iter = textures.into_par_iter();
305
306        #[cfg(target_arch = "wasm32")]
307        let images_iter = textures.into_iter();
308
309        let textures = images_iter
310            .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
311            .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
312        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
313            let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
314            luts.insert(index, texture);
315        }
316        Ok(luts)
317    }
318
319    fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
320        if let Some(mut back) = self.history_framebuffers.pop_back() {
321            if back.image.size() != input.size() || input.format() != back.image.format() {
322                // old back will get dropped.. do we need to defer?
323                let _old_back = std::mem::replace(
324                    &mut back,
325                    OwnedImage::new(
326                        &self.common.device,
327                        input.size().into(),
328                        1,
329                        input.format().into(),
330                    ),
331                );
332            }
333
334            back.copy_from(&self.common.device, cmd, input);
335
336            self.history_framebuffers.push_front(back)
337        }
338    }
339
340    fn init_passes(
341        device: &wgpu::Device,
342        passes: Vec<ShaderPassMeta>,
343        semantics: &ShaderSemantics,
344        adapter_info: Option<&wgpu::AdapterInfo>,
345        disable_cache: bool,
346    ) -> error::Result<Box<[FilterPass]>> {
347        let filter_creation_fn = || {
348            #[cfg(not(target_arch = "wasm32"))]
349            let passes_iter = passes.into_par_iter();
350            #[cfg(target_arch = "wasm32")]
351            let passes_iter = passes.into_iter();
352
353            let filters: Vec<error::Result<FilterPass>> = passes_iter
354                .enumerate()
355                .map(|(index, (config, mut reflect))| {
356                    let reflection = reflect.reflect(index, semantics)?;
357                    let wgsl = reflect.compile(NagaLoweringOptions {
358                        write_pcb_as_ubo: true,
359                        sampler_bind_group: 1,
360                        suppress_derivative_uniformity: true,
361                    })?;
362
363                    let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
364                    let push_size = reflection
365                        .push_constant
366                        .as_ref()
367                        .map_or(0, |push| push.size as wgpu::BufferAddress);
368
369                    let uniform_storage = UniformStorage::new_with_storage(
370                        WgpuStagedBuffer::new(
371                            &device,
372                            wgpu::BufferUsages::UNIFORM,
373                            ubo_size as wgpu::BufferAddress,
374                            Some("ubo"),
375                        ),
376                        WgpuStagedBuffer::new(
377                            &device,
378                            wgpu::BufferUsages::UNIFORM,
379                            push_size as wgpu::BufferAddress,
380                            Some("push"),
381                        ),
382                    );
383
384                    let uniform_bindings =
385                        reflection.meta.create_binding_map(|param| param.offset());
386
387                    let render_pass_format: Option<wgpu::TextureFormat> =
388                        if let Some(format) = config.meta.get_format_override() {
389                            format.into()
390                        } else {
391                            config.data.format.into()
392                        };
393
394                    let graphics_pipeline = WgpuGraphicsPipeline::new(
395                        &device,
396                        &wgsl,
397                        &reflection,
398                        render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
399                        adapter_info,
400                        disable_cache,
401                    );
402
403                    Ok(FilterPass {
404                        reflection,
405                        uniform_storage,
406                        uniform_bindings,
407                        source: config.data,
408                        meta: config.meta,
409                        graphics_pipeline,
410                    })
411                })
412                .collect();
413            filters
414        };
415
416        #[cfg(target_arch = "wasm32")]
417        let filters = filter_creation_fn();
418
419        #[cfg(not(target_arch = "wasm32"))]
420        let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
421            // naga compilations can possibly use degenerate stack sizes.
422            .stack_size(10 * 1048576)
423            .build()
424        {
425            thread_pool.install(|| filter_creation_fn())
426        } else {
427            filter_creation_fn()
428        };
429
430        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
431        let filters = filters?;
432        Ok(filters.into_boxed_slice())
433    }
434
435    /// Records shader rendering commands to the provided command encoder.
436    pub fn frame<'a>(
437        &mut self,
438        input: &wgpu::Texture,
439        viewport: &Viewport<WgpuOutputView<'a>>,
440        cmd: &mut wgpu::CommandEncoder,
441        frame_count: usize,
442        options: Option<&FrameOptionsWgpu>,
443    ) -> error::Result<()> {
444        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
445        let passes = &mut self.passes[0..max];
446
447        if let Some(options) = &options {
448            if options.clear_history {
449                for history in &mut self.history_framebuffers {
450                    history.clear(cmd);
451                }
452            }
453        }
454
455        if passes.is_empty() {
456            return Ok(());
457        }
458
459        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
460
461        let filter = passes[0].meta.filter;
462        let wrap_mode = passes[0].meta.wrap_mode;
463
464        // update history
465        for (texture, image) in self
466            .common
467            .history_textures
468            .iter_mut()
469            .zip(self.history_framebuffers.iter())
470        {
471            *texture = Some(image.as_input(filter, wrap_mode));
472        }
473
474        let original = InputImage {
475            image: input.clone(),
476            view: original_image_view,
477            wrap_mode,
478            filter_mode: filter,
479            mip_filter: filter,
480        };
481
482        let mut source = original.clone();
483
484        let passes_len = passes.len();
485        let options = options.unwrap_or(&self.default_frame_options);
486
487        // swap output and feedback **before** recording command buffers
488        for index in 0..passes_len {
489            if self.feedback_framebuffers.contains(index) {
490                std::mem::swap(
491                    &mut self.output_framebuffers[index],
492                    &mut self.feedback_framebuffers[index],
493                );
494            }
495        }
496
497        let scale_context = self.common.device.clone();
498
499        // rescale feedback buffers and refresh their bound textures.
500        OwnedImage::scale_feedback_framebuffers_with_context(
501            source.image.size().into(),
502            viewport.output.size,
503            original.image.size().into(),
504            &mut self.feedback_framebuffers,
505            passes,
506            &scale_context,
507            |index, pass, feedback| {
508                self.common.feedback_textures[index] =
509                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
510                Ok(())
511            },
512        )?;
513
514        OwnedImage::scale_output_framebuffers_with_context(
515            source.image.size().into(),
516            viewport.output.size,
517            original.image.size().into(),
518            &mut self.output_framebuffers,
519            passes,
520            &scale_context,
521            |index, pass, target, size| {
522                source.filter_mode = pass.meta.filter;
523                source.wrap_mode = pass.meta.wrap_mode;
524                source.mip_filter = pass.meta.filter;
525                let frame_count_pass = pass.meta.get_frame_count(frame_count);
526
527                if index != passes_len - 1 {
528                    let output_image = WgpuOutputView::from(&*target);
529                    let out = RenderTarget::identity(&output_image)?;
530
531                    pass.draw(
532                        cmd,
533                        index,
534                        &self.common,
535                        frame_count_pass,
536                        options,
537                        viewport,
538                        &original,
539                        &source,
540                        &out,
541                        None,
542                        QuadType::Offscreen,
543                    )?;
544
545                    if target.max_miplevels > 1 && !self.disable_mipmaps {
546                        let sampler = self.common.samplers.get(
547                            WrapMode::ClampToEdge,
548                            FilterMode::Linear,
549                            FilterMode::Nearest,
550                        );
551
552                        target.generate_mipmaps(
553                            &self.common.device,
554                            cmd,
555                            &mut self.mipmapper,
556                            &sampler,
557                        );
558                    }
559
560                    self.common.output_textures[index] =
561                        Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode));
562                    source = self.common.output_textures[index].clone().unwrap();
563                    return Ok(());
564                }
565
566                if !pass.graphics_pipeline.has_format(viewport.output.format) {
567                    // need to recompile
568                    pass.graphics_pipeline
569                        .recompile(&self.common.device, viewport.output.format);
570                }
571
572                // When feedback is enabled, render the last pass to the intermediate
573                // framebuffer first then render to the viewport with the OutputSize semantic
574                // overridden to the FB scale.
575                //
576                // Shaders need to see the pass's declared scale rather than the viewport size,
577                // or they won't render correctly for feedback.
578                let output_size_override = if self.draw_last_pass_feedback {
579                    let output_image = WgpuOutputView::from(&*target);
580                    let out = RenderTarget::viewport_with_output(&output_image, viewport);
581
582                    pass.draw(
583                        cmd,
584                        index,
585                        &self.common,
586                        frame_count_pass,
587                        options,
588                        viewport,
589                        &original,
590                        &source,
591                        &out,
592                        None,
593                        QuadType::Final,
594                    )?;
595                    Some(size)
596                } else {
597                    None
598                };
599
600                let out = RenderTarget::viewport(viewport);
601                pass.draw(
602                    cmd,
603                    index,
604                    &self.common,
605                    frame_count_pass,
606                    options,
607                    viewport,
608                    &original,
609                    &source,
610                    &out,
611                    output_size_override,
612                    QuadType::Final,
613                )?;
614
615                Ok(())
616            },
617        )?;
618
619        self.push_history(&input, cmd);
620        Ok(())
621    }
622}