Skip to main content

librashader_runtime_wgpu/
filter_chain.rs

1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::ThreadPoolBuilder;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::{FramebufferInit, FramebufferPool};
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28
29use crate::error;
30use crate::error::FilterChainError;
31use crate::filter_pass::FilterPass;
32use crate::framebuffer::WgpuOutputView;
33use crate::graphics_pipeline::WgpuGraphicsPipeline;
34use crate::luts::LutTexture;
35use crate::mipmap::MipmapGen;
36use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
37use crate::samplers::SamplerSet;
38use crate::texture::{InputImage, OwnedImage};
39
/// Pass compilation using the `SpirvCompilation` front end, targeting WGSL through Naga.
#[cfg(feature = "native")]
mod compile {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};

    /// A shader pass together with its compile/reflect artifact.
    ///
    /// With the `nightly` feature the artifact is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;

    /// A shader pass together with its compile/reflect artifact.
    ///
    /// Without the `nightly` feature the artifact is a boxed trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;

    /// Compile all preset passes for the WGSL target.
    ///
    /// `textures` is only consulted for its per-texture metadata. Returns the compiled passes
    /// along with the shader semantics produced for the preset.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        WGSL::compile_preset_passes::<SpirvCompilation, Naga, FilterChainError>(
            shaders,
            texture_metas,
        )
    }
}
66
/// Pass compilation using the `WgslCompilation` front end, targeting WGSL through Naga.
#[cfg(feature = "wgsl_preset_pack")]
mod compile_wgsl {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};
    use librashader_reflect::front::WgslCompilation;

    /// A shader pass together with its compile/reflect artifact.
    ///
    /// With the `nightly` feature the artifact is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>;

    /// A shader pass together with its compile/reflect artifact.
    ///
    /// Without the `nightly` feature the artifact is a boxed trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>>;

    /// Compile all preset passes for the WGSL target.
    ///
    /// `textures` is only consulted for its per-texture metadata. Returns the compiled passes
    /// along with the shader semantics produced for the preset.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        WGSL::compile_preset_passes::<WgslCompilation, Naga, FilterChainError>(
            shaders,
            texture_metas,
        )
    }
}
94
95#[cfg(all(feature = "native", not(feature = "wgsl_preset_pack")))]
96use compile::{compile_passes, ShaderPassMeta};
97
98#[cfg(any(not(feature = "native"), feature = "wgsl_preset_pack"))]
99use compile_wgsl::{compile_passes, ShaderPassMeta};
100
101use librashader_pack::{ShaderPresetPack, TextureResource};
102use librashader_runtime::parameters::RuntimeParameters;
103
104/// A wgpu filter chain.
105pub struct FilterChainWgpu {
106    pub(crate) common: FilterCommon,
107    passes: Box<[FilterPass]>,
108    output_framebuffers: FramebufferPool<OwnedImage>,
109    feedback_framebuffers: FramebufferPool<OwnedImage>,
110    history_framebuffers: VecDeque<OwnedImage>,
111    disable_mipmaps: bool,
112    mipmapper: MipmapGen,
113    default_frame_options: FrameOptionsWgpu,
114    draw_last_pass_feedback: bool,
115}
116
117pub(crate) struct FilterCommon {
118    pub output_textures: Box<[Option<InputImage>]>,
119    pub feedback_textures: Box<[Option<InputImage>]>,
120    pub history_textures: Box<[Option<InputImage>]>,
121    pub luts: FastHashMap<usize, LutTexture>,
122    pub samplers: SamplerSet,
123    pub config: RuntimeParameters,
124    pub(crate) draw_quad: DrawQuad,
125    pub(crate) device: wgpu::Device,
126    pub(crate) queue: wgpu::Queue,
127}
128
129impl FilterChainWgpu {
130    /// Load the shader preset at the given path into a filter chain.
131    #[cfg(feature = "native")]
132    pub fn load_from_path(
133        path: impl AsRef<Path>,
134        features: ShaderFeatures,
135        device: &wgpu::Device,
136        queue: &wgpu::Queue,
137        options: Option<&FilterChainOptionsWgpu>,
138    ) -> error::Result<FilterChainWgpu> {
139        // load passes from preset
140        let preset = ShaderPreset::try_parse(path, features)?;
141
142        Self::load_from_preset(preset, device, queue, options)
143    }
144
145    /// Load a filter chain from a pre-parsed `ShaderPreset`.
146    #[cfg(feature = "native")]
147    pub fn load_from_preset(
148        preset: ShaderPreset,
149        device: &wgpu::Device,
150        queue: &wgpu::Queue,
151        options: Option<&FilterChainOptionsWgpu>,
152    ) -> error::Result<FilterChainWgpu> {
153        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
154        Self::load_from_pack(preset, device, queue, options)
155    }
156
157    /// Load a filter chain from a pre-parsed and loaded `ShaderPresetPack`.
158    pub fn load_from_pack(
159        preset: ShaderPresetPack,
160        device: &wgpu::Device,
161        queue: &wgpu::Queue,
162        options: Option<&FilterChainOptionsWgpu>,
163    ) -> error::Result<FilterChainWgpu> {
164        let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
165            label: Some("librashader load cmd"),
166        });
167        let filter_chain =
168            Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
169
170        let cmd = cmd.finish();
171
172        // Wait for device
173        let index = queue.submit([cmd]);
174        device.poll(wgpu::PollType::Wait {
175            submission_index: Some(index),
176            timeout: None,
177        })?;
178
179        Ok(filter_chain)
180    }
181
182    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
183    /// to the caller. This function therefore requires no external synchronization of the device queue.
184    ///
185    /// ## Safety
186    /// The provided command buffer must be ready for recording and contain no prior commands.
187    /// The caller is responsible for ending the command buffer and immediately submitting it to a
188    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
189    #[cfg(feature = "native")]
190    pub fn load_from_preset_deferred(
191        preset: ShaderPreset,
192        device: &wgpu::Device,
193        queue: &wgpu::Queue,
194        cmd: &mut wgpu::CommandEncoder,
195        options: Option<&FilterChainOptionsWgpu>,
196    ) -> error::Result<FilterChainWgpu> {
197        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
198        Self::load_from_pack_deferred(preset, device, queue, cmd, options)
199    }
200
201    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
202    /// to the caller. This function therefore requires no external synchronization of the device queue.
203    ///
204    /// ## Safety
205    /// The provided command buffer must be ready for recording and contain no prior commands.
206    /// The caller is responsible for ending the command buffer and immediately submitting it to a
207    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
208    pub fn load_from_pack_deferred(
209        preset: ShaderPresetPack,
210        device: &wgpu::Device,
211        queue: &wgpu::Queue,
212        cmd: &mut wgpu::CommandEncoder,
213        options: Option<&FilterChainOptionsWgpu>,
214    ) -> error::Result<FilterChainWgpu> {
215        let config = RuntimeParameters::new(&preset);
216
217        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
218
219        // cache is opt-in for wgpu, not opt-out because of feature requirements.
220        let disable_cache = options.map_or(true, |o| !o.enable_cache);
221
222        // initialize passes
223        let filters = Self::init_passes(
224            &device,
225            passes,
226            &semantics,
227            options.and_then(|o| o.adapter_info.as_ref()),
228            disable_cache,
229        )?;
230
231        let samplers = SamplerSet::new(&device);
232        let mut mipmapper = MipmapGen::new(&device);
233        let luts = FilterChainWgpu::load_luts(
234            &device,
235            &queue,
236            cmd,
237            &mut mipmapper,
238            &samplers,
239            preset.textures,
240        )?;
241        //
242        let framebuffer_gen = || {
243            Ok::<_, error::FilterChainError>(OwnedImage::new(
244                &device,
245                Size::new(1, 1),
246                1,
247                wgpu::TextureFormat::Bgra8Unorm,
248            ))
249        };
250        let input_gen = || None;
251        let framebuffer_init = FramebufferInit::new(
252            filters.iter().map(|f| &f.reflection.meta),
253            &framebuffer_gen,
254            &input_gen,
255        );
256
257        //
258        // // initialize output framebuffers
259        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
260        //
261        // initialize feedback framebuffers
262        let (feedback_framebuffers, feedback_textures) = framebuffer_init.init_feedback_framebuffers()?;
263        //
264        // initialize history
265        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
266
267        let draw_quad = DrawQuad::new(&device);
268
269        Ok(FilterChainWgpu {
270            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
271            common: FilterCommon {
272                luts,
273                samplers,
274                config,
275                draw_quad,
276                device: device.clone(),
277                queue: queue.clone(),
278                output_textures,
279                feedback_textures,
280                history_textures,
281            },
282            passes: filters,
283            output_framebuffers,
284            feedback_framebuffers,
285            history_framebuffers,
286            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
287            mipmapper,
288            default_frame_options: Default::default(),
289        })
290    }
291
292    fn load_luts(
293        device: &wgpu::Device,
294        queue: &wgpu::Queue,
295        cmd: &mut wgpu::CommandEncoder,
296        mipmapper: &mut MipmapGen,
297        sampler_set: &SamplerSet,
298        textures: Vec<TextureResource>,
299    ) -> error::Result<FastHashMap<usize, LutTexture>> {
300        let mut luts = FastHashMap::default();
301
302        #[cfg(not(target_arch = "wasm32"))]
303        let images_iter = textures.into_par_iter();
304
305        #[cfg(target_arch = "wasm32")]
306        let images_iter = textures.into_iter();
307
308        let textures = images_iter
309            .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
310            .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
311        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
312            let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
313            luts.insert(index, texture);
314        }
315        Ok(luts)
316    }
317
318    fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
319        if let Some(mut back) = self.history_framebuffers.pop_back() {
320            if back.image.size() != input.size() || input.format() != back.image.format() {
321                // old back will get dropped.. do we need to defer?
322                let _old_back = std::mem::replace(
323                    &mut back,
324                    OwnedImage::new(
325                        &self.common.device,
326                        input.size().into(),
327                        1,
328                        input.format().into(),
329                    ),
330                );
331            }
332
333            back.copy_from(&self.common.device, cmd, input);
334
335            self.history_framebuffers.push_front(back)
336        }
337    }
338
339    fn init_passes(
340        device: &wgpu::Device,
341        passes: Vec<ShaderPassMeta>,
342        semantics: &ShaderSemantics,
343        adapter_info: Option<&wgpu::AdapterInfo>,
344        disable_cache: bool,
345    ) -> error::Result<Box<[FilterPass]>> {
346        let filter_creation_fn = || {
347            #[cfg(not(target_arch = "wasm32"))]
348            let passes_iter = passes.into_par_iter();
349            #[cfg(target_arch = "wasm32")]
350            let passes_iter = passes.into_iter();
351
352            let filters: Vec<error::Result<FilterPass>> = passes_iter
353                .enumerate()
354                .map(|(index, (config, mut reflect))| {
355                    let reflection = reflect.reflect(index, semantics)?;
356                    let wgsl = reflect.compile(NagaLoweringOptions {
357                        write_pcb_as_ubo: true,
358                        sampler_bind_group: 1,
359                        suppress_derivative_uniformity: true,
360                    })?;
361
362                    let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
363                    let push_size = reflection
364                        .push_constant
365                        .as_ref()
366                        .map_or(0, |push| push.size as wgpu::BufferAddress);
367
368                    let uniform_storage = UniformStorage::new_with_storage(
369                        WgpuStagedBuffer::new(
370                            &device,
371                            wgpu::BufferUsages::UNIFORM,
372                            ubo_size as wgpu::BufferAddress,
373                            Some("ubo"),
374                        ),
375                        WgpuStagedBuffer::new(
376                            &device,
377                            wgpu::BufferUsages::UNIFORM,
378                            push_size as wgpu::BufferAddress,
379                            Some("push"),
380                        ),
381                    );
382
383                    let uniform_bindings =
384                        reflection.meta.create_binding_map(|param| param.offset());
385
386                    let render_pass_format: Option<wgpu::TextureFormat> =
387                        if let Some(format) = config.meta.get_format_override() {
388                            format.into()
389                        } else {
390                            config.data.format.into()
391                        };
392
393                    let graphics_pipeline = WgpuGraphicsPipeline::new(
394                        &device,
395                        &wgsl,
396                        &reflection,
397                        render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
398                        adapter_info,
399                        disable_cache,
400                    );
401
402                    Ok(FilterPass {
403                        reflection,
404                        uniform_storage,
405                        uniform_bindings,
406                        source: config.data,
407                        meta: config.meta,
408                        graphics_pipeline,
409                    })
410                })
411                .collect();
412            filters
413        };
414
415        #[cfg(target_arch = "wasm32")]
416        let filters = filter_creation_fn();
417
418        #[cfg(not(target_arch = "wasm32"))]
419        let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
420            // naga compilations can possibly use degenerate stack sizes.
421            .stack_size(10 * 1048576)
422            .build()
423        {
424            thread_pool.install(|| filter_creation_fn())
425        } else {
426            filter_creation_fn()
427        };
428
429        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
430        let filters = filters?;
431        Ok(filters.into_boxed_slice())
432    }
433
434    /// Records shader rendering commands to the provided command encoder.
435    pub fn frame<'a>(
436        &mut self,
437        input: &wgpu::Texture,
438        viewport: &Viewport<WgpuOutputView<'a>>,
439        cmd: &mut wgpu::CommandEncoder,
440        frame_count: usize,
441        options: Option<&FrameOptionsWgpu>,
442    ) -> error::Result<()> {
443        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
444        let passes = &mut self.passes[0..max];
445
446        if let Some(options) = &options {
447            if options.clear_history {
448                for history in &mut self.history_framebuffers {
449                    history.clear(cmd);
450                }
451            }
452        }
453
454        if passes.is_empty() {
455            return Ok(());
456        }
457
458        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
459
460        let filter = passes[0].meta.filter;
461        let wrap_mode = passes[0].meta.wrap_mode;
462
463        // update history
464        for (texture, image) in self
465            .common
466            .history_textures
467            .iter_mut()
468            .zip(self.history_framebuffers.iter())
469        {
470            *texture = Some(image.as_input(filter, wrap_mode));
471        }
472
473        let original = InputImage {
474            image: input.clone(),
475            view: original_image_view,
476            wrap_mode,
477            filter_mode: filter,
478            mip_filter: filter,
479        };
480
481        let mut source = original.clone();
482
483        let passes_len = passes.len();
484        let options = options.unwrap_or(&self.default_frame_options);
485
486        // swap output and feedback **before** recording command buffers
487        for index in 0..passes_len {
488            if self.feedback_framebuffers.contains(index) {
489                std::mem::swap(&mut self.output_framebuffers[index], &mut self.feedback_framebuffers[index]);
490            }
491        }
492
493        let scale_context = self.common.device.clone();
494
495        // rescale feedback buffers and refresh their bound textures.
496        OwnedImage::scale_feedback_framebuffers_with_context(
497            source.image.size().into(),
498            viewport.output.size,
499            original.image.size().into(),
500            &mut self.feedback_framebuffers,
501            passes,
502            &scale_context,
503            |index, pass, feedback| {
504                self.common.feedback_textures[index] =
505                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
506                Ok(())
507            },
508        )?;
509
510        OwnedImage::scale_output_framebuffers_with_context(
511            source.image.size().into(),
512            viewport.output.size,
513            original.image.size().into(),
514            &mut self.output_framebuffers,
515            passes,
516            &scale_context,
517            |index, pass, target, size| {
518                source.filter_mode = pass.meta.filter;
519                source.wrap_mode = pass.meta.wrap_mode;
520                source.mip_filter = pass.meta.filter;
521                let frame_count_pass = pass.meta.get_frame_count(frame_count);
522
523                if index != passes_len - 1 {
524                    let output_image = WgpuOutputView::from(&*target);
525                    let out = RenderTarget::identity(&output_image)?;
526
527                    pass.draw(
528                        cmd,
529                        index,
530                        &self.common,
531                        frame_count_pass,
532                        options,
533                        viewport,
534                        &original,
535                        &source,
536                        &out,
537                        None,
538                        QuadType::Offscreen,
539                    )?;
540
541                    if target.max_miplevels > 1 && !self.disable_mipmaps {
542                        let sampler = self.common.samplers.get(
543                            WrapMode::ClampToEdge,
544                            FilterMode::Linear,
545                            FilterMode::Nearest,
546                        );
547
548                        target.generate_mipmaps(
549                            &self.common.device,
550                            cmd,
551                            &mut self.mipmapper,
552                            &sampler,
553                        );
554                    }
555
556                    self.common.output_textures[index] =
557                        Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode));
558                    source = self.common.output_textures[index].clone().unwrap();
559                    return Ok(());
560                }
561
562                if !pass.graphics_pipeline.has_format(viewport.output.format) {
563                    // need to recompile
564                    pass.graphics_pipeline
565                        .recompile(&self.common.device, viewport.output.format);
566                }
567
568                // When feedback is enabled, render the last pass to the intermediate
569                // framebuffer first then render to the viewport with the OutputSize semantic
570                // overridden to the FB scale.
571                //
572                // Shaders need to see the pass's declared scale rather than the viewport size,
573                // or they won't render correctly for feedback.
574                let output_size_override = if self.draw_last_pass_feedback {
575                    let output_image = WgpuOutputView::from(&*target);
576                    let out = RenderTarget::viewport_with_output(&output_image, viewport);
577
578                    pass.draw(
579                        cmd,
580                        index,
581                        &self.common,
582                        frame_count_pass,
583                        options,
584                        viewport,
585                        &original,
586                        &source,
587                        &out,
588                        None,
589                        QuadType::Final,
590                    )?;
591                    Some(size)
592                } else {
593                    None
594                };
595
596                let out = RenderTarget::viewport(viewport);
597                pass.draw(
598                    cmd,
599                    index,
600                    &self.common,
601                    frame_count_pass,
602                    options,
603                    viewport,
604                    &original,
605                    &source,
606                    &out,
607                    output_size_override,
608                    QuadType::Final,
609                )?;
610
611                Ok(())
612            },
613        )?;
614
615        self.push_history(&input, cmd);
616        Ok(())
617    }
618}