1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19use std::sync::Arc;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::FramebufferInit;
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28use wgpu::{Device, TextureFormat};
29
30use crate::error;
31use crate::error::FilterChainError;
32use crate::filter_pass::FilterPass;
33use crate::framebuffer::WgpuOutputView;
34use crate::graphics_pipeline::WgpuGraphicsPipeline;
35use crate::luts::LutTexture;
36use crate::mipmap::MipmapGen;
37use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
38use crate::samplers::SamplerSet;
39use crate::texture::{InputImage, OwnedImage};
40
41mod compile {
42 use super::*;
43 use librashader_pack::{PassResource, TextureResource};
44
45 #[cfg(not(feature = "stable"))]
46 pub type ShaderPassMeta =
47 ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
48
49 #[cfg(feature = "stable")]
50 pub type ShaderPassMeta =
51 ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
52
53 pub fn compile_passes(
54 shaders: Vec<PassResource>,
55 textures: &[TextureResource],
56 ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
57 let (passes, semantics) = WGSL::compile_preset_passes::<
58 SpirvCompilation,
59 Naga,
60 FilterChainError,
61 >(shaders, textures.iter().map(|t| &t.meta))?;
62 Ok((passes, semantics))
63 }
64}
65
66use compile::{compile_passes, ShaderPassMeta};
67use librashader_pack::{ShaderPresetPack, TextureResource};
68use librashader_runtime::parameters::RuntimeParameters;
69
70pub struct FilterChainWgpu {
72 pub(crate) common: FilterCommon,
73 passes: Box<[FilterPass]>,
74 output_framebuffers: Box<[OwnedImage]>,
75 feedback_framebuffers: Box<[OwnedImage]>,
76 history_framebuffers: VecDeque<OwnedImage>,
77 disable_mipmaps: bool,
78 mipmapper: MipmapGen,
79 default_frame_options: FrameOptionsWgpu,
80 draw_last_pass_feedback: bool,
81}
82
83pub(crate) struct FilterCommon {
84 pub output_textures: Box<[Option<InputImage>]>,
85 pub feedback_textures: Box<[Option<InputImage>]>,
86 pub history_textures: Box<[Option<InputImage>]>,
87 pub luts: FastHashMap<usize, LutTexture>,
88 pub samplers: SamplerSet,
89 pub config: RuntimeParameters,
90 pub(crate) draw_quad: DrawQuad,
91 pub(crate) device: Arc<Device>,
92 pub(crate) queue: Arc<wgpu::Queue>,
93}
94
95impl FilterChainWgpu {
96 pub fn load_from_path(
98 path: impl AsRef<Path>,
99 features: ShaderFeatures,
100 device: Arc<Device>,
101 queue: Arc<wgpu::Queue>,
102 options: Option<&FilterChainOptionsWgpu>,
103 ) -> error::Result<FilterChainWgpu> {
104 let preset = ShaderPreset::try_parse(path, features)?;
106
107 Self::load_from_preset(preset, device, queue, options)
108 }
109
110 pub fn load_from_preset(
112 preset: ShaderPreset,
113 device: Arc<Device>,
114 queue: Arc<wgpu::Queue>,
115 options: Option<&FilterChainOptionsWgpu>,
116 ) -> error::Result<FilterChainWgpu> {
117 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
118 Self::load_from_pack(preset, device, queue, options)
119 }
120
121 pub fn load_from_pack(
123 preset: ShaderPresetPack,
124 device: Arc<Device>,
125 queue: Arc<wgpu::Queue>,
126 options: Option<&FilterChainOptionsWgpu>,
127 ) -> error::Result<FilterChainWgpu> {
128 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
129 label: Some("librashader load cmd"),
130 });
131 let filter_chain = Self::load_from_pack_deferred(
132 preset,
133 Arc::clone(&device),
134 Arc::clone(&queue),
135 &mut cmd,
136 options,
137 )?;
138
139 let cmd = cmd.finish();
140
141 let index = queue.submit([cmd]);
143 device.poll(wgpu::Maintain::WaitForSubmissionIndex(index));
144
145 Ok(filter_chain)
146 }
147
148 pub fn load_from_preset_deferred(
156 preset: ShaderPreset,
157 device: Arc<Device>,
158 queue: Arc<wgpu::Queue>,
159 cmd: &mut wgpu::CommandEncoder,
160 options: Option<&FilterChainOptionsWgpu>,
161 ) -> error::Result<FilterChainWgpu> {
162 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
163 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
164 }
165
166 pub fn load_from_pack_deferred(
174 preset: ShaderPresetPack,
175 device: Arc<Device>,
176 queue: Arc<wgpu::Queue>,
177 cmd: &mut wgpu::CommandEncoder,
178 options: Option<&FilterChainOptionsWgpu>,
179 ) -> error::Result<FilterChainWgpu> {
180 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
181
182 let disable_cache = options.map_or(true, |o| !o.enable_cache);
184
185 let filters = Self::init_passes(
187 &device,
188 passes,
189 &semantics,
190 options.and_then(|o| o.adapter_info.as_ref()),
191 disable_cache,
192 )?;
193
194 let samplers = SamplerSet::new(&device);
195 let mut mipmapper = MipmapGen::new(&device);
196 let luts = FilterChainWgpu::load_luts(
197 &device,
198 &queue,
199 cmd,
200 &mut mipmapper,
201 &samplers,
202 preset.textures,
203 )?;
204 let framebuffer_gen = || {
206 Ok::<_, error::FilterChainError>(OwnedImage::new(
207 &device,
208 Size::new(1, 1),
209 1,
210 TextureFormat::Bgra8Unorm,
211 ))
212 };
213 let input_gen = || None;
214 let framebuffer_init = FramebufferInit::new(
215 filters.iter().map(|f| &f.reflection.meta),
216 &framebuffer_gen,
217 &input_gen,
218 );
219
220 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
223 let (feedback_framebuffers, feedback_textures) =
226 framebuffer_init.init_output_framebuffers()?;
227 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
230
231 let draw_quad = DrawQuad::new(&device);
232
233 Ok(FilterChainWgpu {
234 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
235 common: FilterCommon {
236 luts,
237 samplers,
238 config: RuntimeParameters::new(preset.pass_count as usize, preset.parameters),
239 draw_quad,
240 device,
241 queue,
242 output_textures,
243 feedback_textures,
244 history_textures,
245 },
246 passes: filters,
247 output_framebuffers,
248 feedback_framebuffers,
249 history_framebuffers,
250 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
251 mipmapper,
252 default_frame_options: Default::default(),
253 })
254 }
255
256 fn load_luts(
257 device: &wgpu::Device,
258 queue: &wgpu::Queue,
259 cmd: &mut wgpu::CommandEncoder,
260 mipmapper: &mut MipmapGen,
261 sampler_set: &SamplerSet,
262 textures: Vec<TextureResource>,
263 ) -> error::Result<FastHashMap<usize, LutTexture>> {
264 let mut luts = FastHashMap::default();
265
266 #[cfg(not(target_arch = "wasm32"))]
267 let images_iter = textures.into_par_iter();
268
269 #[cfg(target_arch = "wasm32")]
270 let images_iter = textures.into_iter();
271
272 let textures = images_iter
273 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
274 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
275 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
276 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
277 luts.insert(index, texture);
278 }
279 Ok(luts)
280 }
281
282 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
283 if let Some(mut back) = self.history_framebuffers.pop_back() {
284 if back.image.size() != input.size() || input.format() != back.image.format() {
285 let _old_back = std::mem::replace(
287 &mut back,
288 OwnedImage::new(
289 &self.common.device,
290 input.size().into(),
291 1,
292 input.format().into(),
293 ),
294 );
295 }
296
297 back.copy_from(&self.common.device, cmd, input);
298
299 self.history_framebuffers.push_front(back)
300 }
301 }
302
303 fn init_passes(
304 device: &wgpu::Device,
305 passes: Vec<ShaderPassMeta>,
306 semantics: &ShaderSemantics,
307 adapter_info: Option<&wgpu::AdapterInfo>,
308 disable_cache: bool,
309 ) -> error::Result<Box<[FilterPass]>> {
310 #[cfg(not(target_arch = "wasm32"))]
311 let filter_creation_fn = || {
312 let passes_iter = passes.into_par_iter();
313 #[cfg(target_arch = "wasm32")]
314 let passes_iter = passes.into_iter();
315
316 let filters: Vec<error::Result<FilterPass>> = passes_iter
317 .enumerate()
318 .map(|(index, (config, mut reflect))| {
319 let reflection = reflect.reflect(index, semantics)?;
320 let wgsl = reflect.compile(NagaLoweringOptions {
321 write_pcb_as_ubo: true,
322 sampler_bind_group: 1,
323 })?;
324
325 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
326 let push_size = reflection
327 .push_constant
328 .as_ref()
329 .map_or(0, |push| push.size as wgpu::BufferAddress);
330
331 let uniform_storage = UniformStorage::new_with_storage(
332 WgpuStagedBuffer::new(
333 &device,
334 wgpu::BufferUsages::UNIFORM,
335 ubo_size as wgpu::BufferAddress,
336 Some("ubo"),
337 ),
338 WgpuStagedBuffer::new(
339 &device,
340 wgpu::BufferUsages::UNIFORM,
341 push_size as wgpu::BufferAddress,
342 Some("push"),
343 ),
344 );
345
346 let uniform_bindings =
347 reflection.meta.create_binding_map(|param| param.offset());
348
349 let render_pass_format: Option<TextureFormat> =
350 if let Some(format) = config.meta.get_format_override() {
351 format.into()
352 } else {
353 config.data.format.into()
354 };
355
356 let graphics_pipeline = WgpuGraphicsPipeline::new(
357 &device,
358 &wgsl,
359 &reflection,
360 render_pass_format.unwrap_or(TextureFormat::Rgba8Unorm),
361 adapter_info,
362 disable_cache,
363 );
364
365 Ok(FilterPass {
366 reflection,
367 uniform_storage,
368 uniform_bindings,
369 source: config.data,
370 meta: config.meta,
371 graphics_pipeline,
372 })
373 })
374 .collect();
375 filters
376 };
377
378 #[cfg(target_arch = "wasm32")]
379 let filters = filter_creation_fn();
380
381 #[cfg(not(target_arch = "wasm32"))]
382 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
383 .stack_size(10 * 1048576)
385 .build()
386 {
387 thread_pool.install(|| filter_creation_fn())
388 } else {
389 filter_creation_fn()
390 };
391
392 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
393 let filters = filters?;
394 Ok(filters.into_boxed_slice())
395 }
396
397 pub fn frame<'a>(
399 &mut self,
400 input: Arc<wgpu::Texture>,
401 viewport: &Viewport<WgpuOutputView<'a>>,
402 cmd: &mut wgpu::CommandEncoder,
403 frame_count: usize,
404 options: Option<&FrameOptionsWgpu>,
405 ) -> error::Result<()> {
406 let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
407 let passes = &mut self.passes[0..max];
408
409 if let Some(options) = &options {
410 if options.clear_history {
411 for history in &mut self.history_framebuffers {
412 history.clear(cmd);
413 }
414 }
415 }
416
417 if passes.is_empty() {
418 return Ok(());
419 }
420
421 let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
422
423 let filter = passes[0].meta.filter;
424 let wrap_mode = passes[0].meta.wrap_mode;
425
426 for (texture, image) in self
428 .common
429 .history_textures
430 .iter_mut()
431 .zip(self.history_framebuffers.iter())
432 {
433 *texture = Some(image.as_input(filter, wrap_mode));
434 }
435
436 let original = InputImage {
437 image: Arc::clone(&input),
438 view: Arc::new(original_image_view),
439 wrap_mode,
440 filter_mode: filter,
441 mip_filter: filter,
442 };
443
444 let mut source = original.clone();
445
446 std::mem::swap(
448 &mut self.output_framebuffers,
449 &mut self.feedback_framebuffers,
450 );
451
452 OwnedImage::scale_framebuffers_with_context(
454 source.image.size().into(),
455 viewport.output.size,
456 original.image.size().into(),
457 &mut self.output_framebuffers,
458 &mut self.feedback_framebuffers,
459 passes,
460 &self.common.device,
461 Some(&mut |index: usize,
462 pass: &FilterPass,
463 output: &OwnedImage,
464 feedback: &OwnedImage| {
465 self.common.feedback_textures[index] =
467 Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
468 self.common.output_textures[index] =
469 Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
470 Ok(())
471 }),
472 )?;
473
474 let passes_len = passes.len();
475 let (pass, last) = passes.split_at_mut(passes_len - 1);
476
477 let options = options.unwrap_or(&self.default_frame_options);
478
479 for (index, pass) in pass.iter_mut().enumerate() {
480 source.filter_mode = pass.meta.filter;
481 source.wrap_mode = pass.meta.wrap_mode;
482 source.mip_filter = pass.meta.filter;
483
484 let target = &self.output_framebuffers[index];
485 let output_image = WgpuOutputView::from(target);
486 let out = RenderTarget::identity(&output_image)?;
487
488 pass.draw(
489 cmd,
490 index,
491 &self.common,
492 pass.meta.get_frame_count(frame_count),
493 options,
494 viewport,
495 &original,
496 &source,
497 &out,
498 QuadType::Offscreen,
499 )?;
500
501 if target.max_miplevels > 1 && !self.disable_mipmaps {
502 let sampler = self.common.samplers.get(
503 WrapMode::ClampToEdge,
504 FilterMode::Linear,
505 FilterMode::Nearest,
506 );
507
508 target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
509 }
510
511 source = self.common.output_textures[index].clone().unwrap();
512 }
513
514 assert_eq!(last.len(), 1);
516
517 if let Some(pass) = last.iter_mut().next() {
518 let index = passes_len - 1;
519 if !pass.graphics_pipeline.has_format(viewport.output.format) {
520 pass.graphics_pipeline
522 .recompile(&self.common.device, viewport.output.format);
523 }
524
525 source.filter_mode = pass.meta.filter;
526 source.wrap_mode = pass.meta.wrap_mode;
527 source.mip_filter = pass.meta.filter;
528
529 if self.draw_last_pass_feedback {
530 let target = &self.output_framebuffers[index];
531 let output_image = WgpuOutputView::from(target);
532 let out = RenderTarget::viewport_with_output(&output_image, viewport);
533
534 pass.draw(
535 cmd,
536 index,
537 &self.common,
538 pass.meta.get_frame_count(frame_count),
539 options,
540 viewport,
541 &original,
542 &source,
543 &out,
544 QuadType::Final,
545 )?;
546 }
547
548 let out = RenderTarget::viewport(viewport);
549 pass.draw(
550 cmd,
551 index,
552 &self.common,
553 pass.meta.get_frame_count(frame_count),
554 options,
555 viewport,
556 &original,
557 &source,
558 &out,
559 QuadType::Final,
560 )?;
561 }
562
563 self.push_history(&input, cmd);
564 Ok(())
565 }
566}