librashader_runtime_wgpu/
filter_chain.rs

1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19
20use crate::buffer::WgpuStagedBuffer;
21use crate::draw_quad::DrawQuad;
22use librashader_common::{FilterMode, Size, Viewport, WrapMode};
23use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::render_target::RenderTarget;
26use librashader_runtime::scaling::ScaleFramebuffer;
27
28use crate::error;
29use crate::error::FilterChainError;
30use crate::filter_pass::FilterPass;
31use crate::framebuffer::WgpuOutputView;
32use crate::graphics_pipeline::WgpuGraphicsPipeline;
33use crate::luts::LutTexture;
34use crate::mipmap::MipmapGen;
35use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
36use crate::samplers::SamplerSet;
37use crate::texture::{InputImage, OwnedImage};
38
39mod compile {
40    use super::*;
41    use librashader_pack::{PassResource, TextureResource};
42
43    #[cfg(not(feature = "stable"))]
44    pub type ShaderPassMeta =
45        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
46
47    #[cfg(feature = "stable")]
48    pub type ShaderPassMeta =
49        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
50
51    #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
52    pub fn compile_passes(
53        shaders: Vec<PassResource>,
54        textures: &[TextureResource],
55    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
56        let (passes, semantics) = WGSL::compile_preset_passes::<
57            SpirvCompilation,
58            Naga,
59            FilterChainError,
60        >(shaders, textures.iter().map(|t| &t.meta))?;
61        Ok((passes, semantics))
62    }
63}
64
65use compile::{compile_passes, ShaderPassMeta};
66use librashader_pack::{ShaderPresetPack, TextureResource};
67use librashader_runtime::parameters::RuntimeParameters;
68
69/// A wgpu filter chain.
70pub struct FilterChainWgpu {
71    pub(crate) common: FilterCommon,
72    passes: Box<[FilterPass]>,
73    output_framebuffers: Box<[OwnedImage]>,
74    feedback_framebuffers: Box<[OwnedImage]>,
75    history_framebuffers: VecDeque<OwnedImage>,
76    disable_mipmaps: bool,
77    mipmapper: MipmapGen,
78    default_frame_options: FrameOptionsWgpu,
79    draw_last_pass_feedback: bool,
80}
81
82pub(crate) struct FilterCommon {
83    pub output_textures: Box<[Option<InputImage>]>,
84    pub feedback_textures: Box<[Option<InputImage>]>,
85    pub history_textures: Box<[Option<InputImage>]>,
86    pub luts: FastHashMap<usize, LutTexture>,
87    pub samplers: SamplerSet,
88    pub config: RuntimeParameters,
89    pub(crate) draw_quad: DrawQuad,
90    pub(crate) device: wgpu::Device,
91    pub(crate) queue: wgpu::Queue,
92}
93
94impl FilterChainWgpu {
95    /// Load the shader preset at the given path into a filter chain.
96    pub fn load_from_path(
97        path: impl AsRef<Path>,
98        features: ShaderFeatures,
99        device: &wgpu::Device,
100        queue: &wgpu::Queue,
101        options: Option<&FilterChainOptionsWgpu>,
102    ) -> error::Result<FilterChainWgpu> {
103        // load passes from preset
104        let preset = ShaderPreset::try_parse(path, features)?;
105
106        Self::load_from_preset(preset, device, queue, options)
107    }
108
109    /// Load a filter chain from a pre-parsed `ShaderPreset`.
110    pub fn load_from_preset(
111        preset: ShaderPreset,
112        device: &wgpu::Device,
113        queue: &wgpu::Queue,
114        options: Option<&FilterChainOptionsWgpu>,
115    ) -> error::Result<FilterChainWgpu> {
116        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
117        Self::load_from_pack(preset, device, queue, options)
118    }
119
120    /// Load a filter chain from a pre-parsed and loaded `ShaderPresetPack`.
121    pub fn load_from_pack(
122        preset: ShaderPresetPack,
123        device: &wgpu::Device,
124        queue: &wgpu::Queue,
125        options: Option<&FilterChainOptionsWgpu>,
126    ) -> error::Result<FilterChainWgpu> {
127        let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
128            label: Some("librashader load cmd"),
129        });
130        let filter_chain =
131            Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
132
133        let cmd = cmd.finish();
134
135        // Wait for device
136        let index = queue.submit([cmd]);
137        device.poll(wgpu::PollType::Wait {
138            submission_index: Some(index),
139            timeout: None,
140        })?;
141
142        Ok(filter_chain)
143    }
144
145    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
146    /// to the caller. This function therefore requires no external synchronization of the device queue.
147    ///
148    /// ## Safety
149    /// The provided command buffer must be ready for recording and contain no prior commands.
150    /// The caller is responsible for ending the command buffer and immediately submitting it to a
151    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
152    pub fn load_from_preset_deferred(
153        preset: ShaderPreset,
154        device: &wgpu::Device,
155        queue: &wgpu::Queue,
156        cmd: &mut wgpu::CommandEncoder,
157        options: Option<&FilterChainOptionsWgpu>,
158    ) -> error::Result<FilterChainWgpu> {
159        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
160        Self::load_from_pack_deferred(preset, device, queue, cmd, options)
161    }
162
163    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
164    /// to the caller. This function therefore requires no external synchronization of the device queue.
165    ///
166    /// ## Safety
167    /// The provided command buffer must be ready for recording and contain no prior commands.
168    /// The caller is responsible for ending the command buffer and immediately submitting it to a
169    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
170    pub fn load_from_pack_deferred(
171        preset: ShaderPresetPack,
172        device: &wgpu::Device,
173        queue: &wgpu::Queue,
174        cmd: &mut wgpu::CommandEncoder,
175        options: Option<&FilterChainOptionsWgpu>,
176    ) -> error::Result<FilterChainWgpu> {
177        let config = RuntimeParameters::new(&preset);
178
179        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
180
181        // cache is opt-in for wgpu, not opt-out because of feature requirements.
182        let disable_cache = options.map_or(true, |o| !o.enable_cache);
183
184        // initialize passes
185        let filters = Self::init_passes(
186            &device,
187            passes,
188            &semantics,
189            options.and_then(|o| o.adapter_info.as_ref()),
190            disable_cache,
191        )?;
192
193        let samplers = SamplerSet::new(&device);
194        let mut mipmapper = MipmapGen::new(&device);
195        let luts = FilterChainWgpu::load_luts(
196            &device,
197            &queue,
198            cmd,
199            &mut mipmapper,
200            &samplers,
201            preset.textures,
202        )?;
203        //
204        let framebuffer_gen = || {
205            Ok::<_, error::FilterChainError>(OwnedImage::new(
206                &device,
207                Size::new(1, 1),
208                1,
209                wgpu::TextureFormat::Bgra8Unorm,
210            ))
211        };
212        let input_gen = || None;
213        let framebuffer_init = FramebufferInit::new(
214            filters.iter().map(|f| &f.reflection.meta),
215            &framebuffer_gen,
216            &input_gen,
217        );
218
219        //
220        // // initialize output framebuffers
221        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
222        //
223        // initialize feedback framebuffers
224        let (feedback_framebuffers, feedback_textures) =
225            framebuffer_init.init_output_framebuffers()?;
226        //
227        // initialize history
228        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
229
230        let draw_quad = DrawQuad::new(&device);
231
232        Ok(FilterChainWgpu {
233            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
234            common: FilterCommon {
235                luts,
236                samplers,
237                config,
238                draw_quad,
239                device: device.clone(),
240                queue: queue.clone(),
241                output_textures,
242                feedback_textures,
243                history_textures,
244            },
245            passes: filters,
246            output_framebuffers,
247            feedback_framebuffers,
248            history_framebuffers,
249            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
250            mipmapper,
251            default_frame_options: Default::default(),
252        })
253    }
254
255    fn load_luts(
256        device: &wgpu::Device,
257        queue: &wgpu::Queue,
258        cmd: &mut wgpu::CommandEncoder,
259        mipmapper: &mut MipmapGen,
260        sampler_set: &SamplerSet,
261        textures: Vec<TextureResource>,
262    ) -> error::Result<FastHashMap<usize, LutTexture>> {
263        let mut luts = FastHashMap::default();
264
265        #[cfg(not(target_arch = "wasm32"))]
266        let images_iter = textures.into_par_iter();
267
268        #[cfg(target_arch = "wasm32")]
269        let images_iter = textures.into_iter();
270
271        let textures = images_iter
272            .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
273            .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
274        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
275            let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
276            luts.insert(index, texture);
277        }
278        Ok(luts)
279    }
280
281    fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
282        if let Some(mut back) = self.history_framebuffers.pop_back() {
283            if back.image.size() != input.size() || input.format() != back.image.format() {
284                // old back will get dropped.. do we need to defer?
285                let _old_back = std::mem::replace(
286                    &mut back,
287                    OwnedImage::new(
288                        &self.common.device,
289                        input.size().into(),
290                        1,
291                        input.format().into(),
292                    ),
293                );
294            }
295
296            back.copy_from(&self.common.device, cmd, input);
297
298            self.history_framebuffers.push_front(back)
299        }
300    }
301
302    fn init_passes(
303        device: &wgpu::Device,
304        passes: Vec<ShaderPassMeta>,
305        semantics: &ShaderSemantics,
306        adapter_info: Option<&wgpu::AdapterInfo>,
307        disable_cache: bool,
308    ) -> error::Result<Box<[FilterPass]>> {
309        #[cfg(not(target_arch = "wasm32"))]
310        let filter_creation_fn = || {
311            let passes_iter = passes.into_par_iter();
312            #[cfg(target_arch = "wasm32")]
313            let passes_iter = passes.into_iter();
314
315            let filters: Vec<error::Result<FilterPass>> = passes_iter
316                .enumerate()
317                .map(|(index, (config, mut reflect))| {
318                    let reflection = reflect.reflect(index, semantics)?;
319                    let wgsl = reflect.compile(NagaLoweringOptions {
320                        write_pcb_as_ubo: true,
321                        sampler_bind_group: 1,
322                    })?;
323
324                    let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
325                    let push_size = reflection
326                        .push_constant
327                        .as_ref()
328                        .map_or(0, |push| push.size as wgpu::BufferAddress);
329
330                    let uniform_storage = UniformStorage::new_with_storage(
331                        WgpuStagedBuffer::new(
332                            &device,
333                            wgpu::BufferUsages::UNIFORM,
334                            ubo_size as wgpu::BufferAddress,
335                            Some("ubo"),
336                        ),
337                        WgpuStagedBuffer::new(
338                            &device,
339                            wgpu::BufferUsages::UNIFORM,
340                            push_size as wgpu::BufferAddress,
341                            Some("push"),
342                        ),
343                    );
344
345                    let uniform_bindings =
346                        reflection.meta.create_binding_map(|param| param.offset());
347
348                    let render_pass_format: Option<wgpu::TextureFormat> =
349                        if let Some(format) = config.meta.get_format_override() {
350                            format.into()
351                        } else {
352                            config.data.format.into()
353                        };
354
355                    let graphics_pipeline = WgpuGraphicsPipeline::new(
356                        &device,
357                        &wgsl,
358                        &reflection,
359                        render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
360                        adapter_info,
361                        disable_cache,
362                    );
363
364                    Ok(FilterPass {
365                        reflection,
366                        uniform_storage,
367                        uniform_bindings,
368                        source: config.data,
369                        meta: config.meta,
370                        graphics_pipeline,
371                    })
372                })
373                .collect();
374            filters
375        };
376
377        #[cfg(target_arch = "wasm32")]
378        let filters = filter_creation_fn();
379
380        #[cfg(not(target_arch = "wasm32"))]
381        let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
382            // naga compilations can possibly use degenerate stack sizes.
383            .stack_size(10 * 1048576)
384            .build()
385        {
386            thread_pool.install(|| filter_creation_fn())
387        } else {
388            filter_creation_fn()
389        };
390
391        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
392        let filters = filters?;
393        Ok(filters.into_boxed_slice())
394    }
395
396    /// Records shader rendering commands to the provided command encoder.
397    pub fn frame<'a>(
398        &mut self,
399        input: &wgpu::Texture,
400        viewport: &Viewport<WgpuOutputView<'a>>,
401        cmd: &mut wgpu::CommandEncoder,
402        frame_count: usize,
403        options: Option<&FrameOptionsWgpu>,
404    ) -> error::Result<()> {
405        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
406        let passes = &mut self.passes[0..max];
407
408        if let Some(options) = &options {
409            if options.clear_history {
410                for history in &mut self.history_framebuffers {
411                    history.clear(cmd);
412                }
413            }
414        }
415
416        if passes.is_empty() {
417            return Ok(());
418        }
419
420        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
421
422        let filter = passes[0].meta.filter;
423        let wrap_mode = passes[0].meta.wrap_mode;
424
425        // update history
426        for (texture, image) in self
427            .common
428            .history_textures
429            .iter_mut()
430            .zip(self.history_framebuffers.iter())
431        {
432            *texture = Some(image.as_input(filter, wrap_mode));
433        }
434
435        let original = InputImage {
436            image: input.clone(),
437            view: original_image_view,
438            wrap_mode,
439            filter_mode: filter,
440            mip_filter: filter,
441        };
442
443        let mut source = original.clone();
444
445        // swap output and feedback **before** recording command buffers
446        std::mem::swap(
447            &mut self.output_framebuffers,
448            &mut self.feedback_framebuffers,
449        );
450
451        // rescale render buffers to ensure all bindings are valid.
452        OwnedImage::scale_framebuffers_with_context(
453            source.image.size().into(),
454            viewport.output.size,
455            original.image.size().into(),
456            &mut self.output_framebuffers,
457            &mut self.feedback_framebuffers,
458            passes,
459            &self.common.device,
460            Some(&mut |index: usize,
461                       pass: &FilterPass,
462                       output: &OwnedImage,
463                       feedback: &OwnedImage| {
464                // refresh inputs
465                self.common.feedback_textures[index] =
466                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
467                self.common.output_textures[index] =
468                    Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
469                Ok(())
470            }),
471        )?;
472
473        let passes_len = passes.len();
474        let (pass, last) = passes.split_at_mut(passes_len - 1);
475
476        let options = options.unwrap_or(&self.default_frame_options);
477
478        for (index, pass) in pass.iter_mut().enumerate() {
479            source.filter_mode = pass.meta.filter;
480            source.wrap_mode = pass.meta.wrap_mode;
481            source.mip_filter = pass.meta.filter;
482
483            let target = &self.output_framebuffers[index];
484            let output_image = WgpuOutputView::from(target);
485            let out = RenderTarget::identity(&output_image)?;
486
487            pass.draw(
488                cmd,
489                index,
490                &self.common,
491                pass.meta.get_frame_count(frame_count),
492                options,
493                viewport,
494                &original,
495                &source,
496                &out,
497                QuadType::Offscreen,
498            )?;
499
500            if target.max_miplevels > 1 && !self.disable_mipmaps {
501                let sampler = self.common.samplers.get(
502                    WrapMode::ClampToEdge,
503                    FilterMode::Linear,
504                    FilterMode::Nearest,
505                );
506
507                target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
508            }
509
510            source = self.common.output_textures[index].clone().unwrap();
511        }
512
513        // try to hint the optimizer
514        assert_eq!(last.len(), 1);
515
516        if let Some(pass) = last.iter_mut().next() {
517            let index = passes_len - 1;
518            if !pass.graphics_pipeline.has_format(viewport.output.format) {
519                // need to recompile
520                pass.graphics_pipeline
521                    .recompile(&self.common.device, viewport.output.format);
522            }
523
524            source.filter_mode = pass.meta.filter;
525            source.wrap_mode = pass.meta.wrap_mode;
526            source.mip_filter = pass.meta.filter;
527
528            if self.draw_last_pass_feedback {
529                let target = &self.output_framebuffers[index];
530                let output_image = WgpuOutputView::from(target);
531                let out = RenderTarget::viewport_with_output(&output_image, viewport);
532
533                pass.draw(
534                    cmd,
535                    index,
536                    &self.common,
537                    pass.meta.get_frame_count(frame_count),
538                    options,
539                    viewport,
540                    &original,
541                    &source,
542                    &out,
543                    QuadType::Final,
544                )?;
545            }
546
547            let out = RenderTarget::viewport(viewport);
548            pass.draw(
549                cmd,
550                index,
551                &self.common,
552                pass.meta.get_frame_count(frame_count),
553                options,
554                viewport,
555                &original,
556                &source,
557                &out,
558                QuadType::Final,
559            )?;
560        }
561
562        self.push_history(&input, cmd);
563        Ok(())
564    }
565}