1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19
20use crate::buffer::WgpuStagedBuffer;
21use crate::draw_quad::DrawQuad;
22use librashader_common::{FilterMode, Size, Viewport, WrapMode};
23use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::render_target::RenderTarget;
26use librashader_runtime::scaling::ScaleFramebuffer;
27
28use crate::error;
29use crate::error::FilterChainError;
30use crate::filter_pass::FilterPass;
31use crate::framebuffer::WgpuOutputView;
32use crate::graphics_pipeline::WgpuGraphicsPipeline;
33use crate::luts::LutTexture;
34use crate::mipmap::MipmapGen;
35use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
36use crate::samplers::SamplerSet;
37use crate::texture::{InputImage, OwnedImage};
38
39mod compile {
40 use super::*;
41 use librashader_pack::{PassResource, TextureResource};
42
43 #[cfg(not(feature = "stable"))]
44 pub type ShaderPassMeta =
45 ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
46
47 #[cfg(feature = "stable")]
48 pub type ShaderPassMeta =
49 ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
50
51 #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
52 pub fn compile_passes(
53 shaders: Vec<PassResource>,
54 textures: &[TextureResource],
55 ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
56 let (passes, semantics) = WGSL::compile_preset_passes::<
57 SpirvCompilation,
58 Naga,
59 FilterChainError,
60 >(shaders, textures.iter().map(|t| &t.meta))?;
61 Ok((passes, semantics))
62 }
63}
64
65use compile::{compile_passes, ShaderPassMeta};
66use librashader_pack::{ShaderPresetPack, TextureResource};
67use librashader_runtime::parameters::RuntimeParameters;
68
/// A wgpu filter chain: the compiled shader passes of a preset together with
/// the framebuffers and helpers needed to record them with [`FilterChainWgpu::frame`].
pub struct FilterChainWgpu {
    /// State shared with every pass while drawing (textures, samplers, device, ...).
    pub(crate) common: FilterCommon,
    /// The filter passes, in the order produced by `init_passes`.
    passes: Box<[FilterPass]>,
    /// Render targets the passes draw into; indexed by pass index in `frame`.
    output_framebuffers: Box<[OwnedImage]>,
    /// Swapped with `output_framebuffers` at the start of every `frame` call,
    /// so it holds what was the output set before the swap.
    feedback_framebuffers: Box<[OwnedImage]>,
    /// History images; `push_history` recycles the back entry for the
    /// current input and pushes it to the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// Taken from `FilterChainOptionsWgpu::force_no_mipmaps`; when `true`,
    /// `frame` never generates mipmaps for pass outputs.
    disable_mipmaps: bool,
    /// Mipmap generator used for pass outputs in `frame` and handed to
    /// `LutTexture::new` while loading LUTs.
    mipmapper: MipmapGen,
    /// Frame options used by `frame` when the caller passes `None`.
    default_frame_options: FrameOptionsWgpu,
    /// Result of `FramebufferInit::uses_final_pass_as_feedback`; when `true`,
    /// `frame` additionally draws the final pass into its own output framebuffer.
    draw_last_pass_feedback: bool,
}
81
/// State shared by all passes of a [`FilterChainWgpu`]; a reference to it is
/// passed to every `FilterPass::draw` call.
pub(crate) struct FilterCommon {
    /// Input views of the output framebuffers, indexed by pass; refreshed in `frame`.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Input views of the feedback framebuffers, indexed by pass; refreshed in `frame`.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Input views of the history framebuffers; refreshed in `frame`.
    pub history_textures: Box<[Option<InputImage>]>,
    /// LUT textures keyed by their position in the preset pack's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Samplers created once for the device when the chain is loaded.
    pub samplers: SamplerSet,
    /// Runtime parameters built from the preset's pass count and parameter list.
    pub config: RuntimeParameters,
    /// Quad helper created with `DrawQuad::new` at load time.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device the chain was created with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue the chain was created with.
    pub(crate) queue: wgpu::Queue,
}
93
94impl FilterChainWgpu {
95 pub fn load_from_path(
97 path: impl AsRef<Path>,
98 features: ShaderFeatures,
99 device: &wgpu::Device,
100 queue: &wgpu::Queue,
101 options: Option<&FilterChainOptionsWgpu>,
102 ) -> error::Result<FilterChainWgpu> {
103 let preset = ShaderPreset::try_parse(path, features)?;
105
106 Self::load_from_preset(preset, device, queue, options)
107 }
108
109 pub fn load_from_preset(
111 preset: ShaderPreset,
112 device: &wgpu::Device,
113 queue: &wgpu::Queue,
114 options: Option<&FilterChainOptionsWgpu>,
115 ) -> error::Result<FilterChainWgpu> {
116 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
117 Self::load_from_pack(preset, device, queue, options)
118 }
119
120 pub fn load_from_pack(
122 preset: ShaderPresetPack,
123 device: &wgpu::Device,
124 queue: &wgpu::Queue,
125 options: Option<&FilterChainOptionsWgpu>,
126 ) -> error::Result<FilterChainWgpu> {
127 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
128 label: Some("librashader load cmd"),
129 });
130 let filter_chain =
131 Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
132
133 let cmd = cmd.finish();
134
135 let index = queue.submit([cmd]);
137 device.poll(wgpu::Maintain::WaitForSubmissionIndex(index));
138
139 Ok(filter_chain)
140 }
141
142 pub fn load_from_preset_deferred(
150 preset: ShaderPreset,
151 device: &wgpu::Device,
152 queue: &wgpu::Queue,
153 cmd: &mut wgpu::CommandEncoder,
154 options: Option<&FilterChainOptionsWgpu>,
155 ) -> error::Result<FilterChainWgpu> {
156 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
157 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
158 }
159
160 pub fn load_from_pack_deferred(
168 preset: ShaderPresetPack,
169 device: &wgpu::Device,
170 queue: &wgpu::Queue,
171 cmd: &mut wgpu::CommandEncoder,
172 options: Option<&FilterChainOptionsWgpu>,
173 ) -> error::Result<FilterChainWgpu> {
174 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
175
176 let disable_cache = options.map_or(true, |o| !o.enable_cache);
178
179 let filters = Self::init_passes(
181 &device,
182 passes,
183 &semantics,
184 options.and_then(|o| o.adapter_info.as_ref()),
185 disable_cache,
186 )?;
187
188 let samplers = SamplerSet::new(&device);
189 let mut mipmapper = MipmapGen::new(&device);
190 let luts = FilterChainWgpu::load_luts(
191 &device,
192 &queue,
193 cmd,
194 &mut mipmapper,
195 &samplers,
196 preset.textures,
197 )?;
198 let framebuffer_gen = || {
200 Ok::<_, error::FilterChainError>(OwnedImage::new(
201 &device,
202 Size::new(1, 1),
203 1,
204 wgpu::TextureFormat::Bgra8Unorm,
205 ))
206 };
207 let input_gen = || None;
208 let framebuffer_init = FramebufferInit::new(
209 filters.iter().map(|f| &f.reflection.meta),
210 &framebuffer_gen,
211 &input_gen,
212 );
213
214 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
217 let (feedback_framebuffers, feedback_textures) =
220 framebuffer_init.init_output_framebuffers()?;
221 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
224
225 let draw_quad = DrawQuad::new(&device);
226
227 Ok(FilterChainWgpu {
228 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
229 common: FilterCommon {
230 luts,
231 samplers,
232 config: RuntimeParameters::new(preset.pass_count as usize, preset.parameters),
233 draw_quad,
234 device: device.clone(),
235 queue: queue.clone(),
236 output_textures,
237 feedback_textures,
238 history_textures,
239 },
240 passes: filters,
241 output_framebuffers,
242 feedback_framebuffers,
243 history_framebuffers,
244 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
245 mipmapper,
246 default_frame_options: Default::default(),
247 })
248 }
249
250 fn load_luts(
251 device: &wgpu::Device,
252 queue: &wgpu::Queue,
253 cmd: &mut wgpu::CommandEncoder,
254 mipmapper: &mut MipmapGen,
255 sampler_set: &SamplerSet,
256 textures: Vec<TextureResource>,
257 ) -> error::Result<FastHashMap<usize, LutTexture>> {
258 let mut luts = FastHashMap::default();
259
260 #[cfg(not(target_arch = "wasm32"))]
261 let images_iter = textures.into_par_iter();
262
263 #[cfg(target_arch = "wasm32")]
264 let images_iter = textures.into_iter();
265
266 let textures = images_iter
267 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
268 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
269 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
270 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
271 luts.insert(index, texture);
272 }
273 Ok(luts)
274 }
275
276 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
277 if let Some(mut back) = self.history_framebuffers.pop_back() {
278 if back.image.size() != input.size() || input.format() != back.image.format() {
279 let _old_back = std::mem::replace(
281 &mut back,
282 OwnedImage::new(
283 &self.common.device,
284 input.size().into(),
285 1,
286 input.format().into(),
287 ),
288 );
289 }
290
291 back.copy_from(&self.common.device, cmd, input);
292
293 self.history_framebuffers.push_front(back)
294 }
295 }
296
297 fn init_passes(
298 device: &wgpu::Device,
299 passes: Vec<ShaderPassMeta>,
300 semantics: &ShaderSemantics,
301 adapter_info: Option<&wgpu::AdapterInfo>,
302 disable_cache: bool,
303 ) -> error::Result<Box<[FilterPass]>> {
304 #[cfg(not(target_arch = "wasm32"))]
305 let filter_creation_fn = || {
306 let passes_iter = passes.into_par_iter();
307 #[cfg(target_arch = "wasm32")]
308 let passes_iter = passes.into_iter();
309
310 let filters: Vec<error::Result<FilterPass>> = passes_iter
311 .enumerate()
312 .map(|(index, (config, mut reflect))| {
313 let reflection = reflect.reflect(index, semantics)?;
314 let wgsl = reflect.compile(NagaLoweringOptions {
315 write_pcb_as_ubo: true,
316 sampler_bind_group: 1,
317 })?;
318
319 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
320 let push_size = reflection
321 .push_constant
322 .as_ref()
323 .map_or(0, |push| push.size as wgpu::BufferAddress);
324
325 let uniform_storage = UniformStorage::new_with_storage(
326 WgpuStagedBuffer::new(
327 &device,
328 wgpu::BufferUsages::UNIFORM,
329 ubo_size as wgpu::BufferAddress,
330 Some("ubo"),
331 ),
332 WgpuStagedBuffer::new(
333 &device,
334 wgpu::BufferUsages::UNIFORM,
335 push_size as wgpu::BufferAddress,
336 Some("push"),
337 ),
338 );
339
340 let uniform_bindings =
341 reflection.meta.create_binding_map(|param| param.offset());
342
343 let render_pass_format: Option<wgpu::TextureFormat> =
344 if let Some(format) = config.meta.get_format_override() {
345 format.into()
346 } else {
347 config.data.format.into()
348 };
349
350 let graphics_pipeline = WgpuGraphicsPipeline::new(
351 &device,
352 &wgsl,
353 &reflection,
354 render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
355 adapter_info,
356 disable_cache,
357 );
358
359 Ok(FilterPass {
360 reflection,
361 uniform_storage,
362 uniform_bindings,
363 source: config.data,
364 meta: config.meta,
365 graphics_pipeline,
366 })
367 })
368 .collect();
369 filters
370 };
371
372 #[cfg(target_arch = "wasm32")]
373 let filters = filter_creation_fn();
374
375 #[cfg(not(target_arch = "wasm32"))]
376 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
377 .stack_size(10 * 1048576)
379 .build()
380 {
381 thread_pool.install(|| filter_creation_fn())
382 } else {
383 filter_creation_fn()
384 };
385
386 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
387 let filters = filters?;
388 Ok(filters.into_boxed_slice())
389 }
390
/// Records the commands that run the filter chain on `input` into `cmd`.
///
/// * `input` - the source texture for this frame.
/// * `viewport` - the final output target; its size feeds framebuffer
///   scaling and its format selects the last pass's pipeline.
/// * `cmd` - the encoder all work is recorded into; it is neither finished
///   nor submitted here.
/// * `frame_count` - passed through each pass's `get_frame_count`.
/// * `options` - per-frame options; `default_frame_options` is used when `None`.
///
/// Only the first `min(passes.len(), passes_enabled())` passes run. If that
/// is zero, the function returns after the optional history clear without
/// drawing and without pushing `input` to history.
///
/// # Errors
///
/// Returns an error if framebuffer scaling, render target creation or any
/// pass draw fails.
pub fn frame<'a>(
    &mut self,
    input: &wgpu::Texture,
    viewport: &Viewport<WgpuOutputView<'a>>,
    cmd: &mut wgpu::CommandEncoder,
    frame_count: usize,
    options: Option<&FrameOptionsWgpu>,
) -> error::Result<()> {
    // Restrict the chain to the passes that are currently enabled.
    let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
    let passes = &mut self.passes[0..max];

    // Clearing history is honoured even when no pass ends up running.
    if let Some(options) = &options {
        if options.clear_history {
            for history in &mut self.history_framebuffers {
                history.clear(cmd);
            }
        }
    }

    if passes.is_empty() {
        return Ok(());
    }

    let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());

    // The first pass's filter and wrap mode are used for the original
    // image and for all history inputs.
    let filter = passes[0].meta.filter;
    let wrap_mode = passes[0].meta.wrap_mode;

    // Refresh the history input views from the current history framebuffers.
    for (texture, image) in self
        .common
        .history_textures
        .iter_mut()
        .zip(self.history_framebuffers.iter())
    {
        *texture = Some(image.as_input(filter, wrap_mode));
    }

    let original = InputImage {
        image: input.clone(),
        view: original_image_view,
        wrap_mode,
        filter_mode: filter,
        mip_filter: filter,
    };

    // `source` is the input of the next pass; it starts as the original image.
    let mut source = original.clone();

    // Swap the two framebuffer sets before anything is recorded: what was
    // the output set becomes this frame's feedback set.
    std::mem::swap(
        &mut self.output_framebuffers,
        &mut self.feedback_framebuffers,
    );

    // Let the scaling helper process both framebuffer sets; the callback
    // refreshes the per-pass feedback/output input views it is handed.
    OwnedImage::scale_framebuffers_with_context(
        source.image.size().into(),
        viewport.output.size,
        original.image.size().into(),
        &mut self.output_framebuffers,
        &mut self.feedback_framebuffers,
        passes,
        &self.common.device,
        Some(&mut |index: usize,
                   pass: &FilterPass,
                   output: &OwnedImage,
                   feedback: &OwnedImage| {
            self.common.feedback_textures[index] =
                Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
            self.common.output_textures[index] =
                Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
            Ok(())
        }),
    )?;

    // Split off the final pass; `passes` is non-empty here, so the
    // subtraction cannot underflow.
    let passes_len = passes.len();
    let (pass, last) = passes.split_at_mut(passes_len - 1);

    let options = options.unwrap_or(&self.default_frame_options);

    // All passes but the last draw offscreen into their own output framebuffer.
    for (index, pass) in pass.iter_mut().enumerate() {
        source.filter_mode = pass.meta.filter;
        source.wrap_mode = pass.meta.wrap_mode;
        source.mip_filter = pass.meta.filter;

        let target = &self.output_framebuffers[index];
        let output_image = WgpuOutputView::from(target);
        let out = RenderTarget::identity(&output_image)?;

        pass.draw(
            cmd,
            index,
            &self.common,
            pass.meta.get_frame_count(frame_count),
            options,
            viewport,
            &original,
            &source,
            &out,
            QuadType::Offscreen,
        )?;

        // Generate mipmaps only for targets with more than one mip level,
        // and only if mipmaps were not force-disabled at load time.
        if target.max_miplevels > 1 && !self.disable_mipmaps {
            let sampler = self.common.samplers.get(
                WrapMode::ClampToEdge,
                FilterMode::Linear,
                FilterMode::Nearest,
            );

            target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
        }

        // This pass's output is the next pass's source.
        source = self.common.output_textures[index].clone().unwrap();
    }

    // Exactly one pass remains after the split above.
    assert_eq!(last.len(), 1);

    if let Some(pass) = last.iter_mut().next() {
        let index = passes_len - 1;
        // Recompile the final pipeline if it lacks the viewport's output format.
        if !pass.graphics_pipeline.has_format(viewport.output.format) {
            pass.graphics_pipeline
                .recompile(&self.common.device, viewport.output.format);
        }

        source.filter_mode = pass.meta.filter;
        source.wrap_mode = pass.meta.wrap_mode;
        source.mip_filter = pass.meta.filter;

        // When the chain uses the final pass as feedback, additionally draw
        // it into its own output framebuffer.
        if self.draw_last_pass_feedback {
            let target = &self.output_framebuffers[index];
            let output_image = WgpuOutputView::from(target);
            let out = RenderTarget::viewport_with_output(&output_image, viewport);

            pass.draw(
                cmd,
                index,
                &self.common,
                pass.meta.get_frame_count(frame_count),
                options,
                viewport,
                &original,
                &source,
                &out,
                QuadType::Final,
            )?;
        }

        // The final pass always draws to the caller's viewport.
        let out = RenderTarget::viewport(viewport);
        pass.draw(
            cmd,
            index,
            &self.common,
            pass.meta.get_frame_count(frame_count),
            options,
            viewport,
            &original,
            &source,
            &out,
            QuadType::Final,
        )?;
    }

    // Record this frame's input into the history ring for later frames.
    self.push_history(&input, cmd);
    Ok(())
}
560}