1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18use rayon::ThreadPoolBuilder;
19
20use crate::buffer::WgpuStagedBuffer;
21use crate::draw_quad::DrawQuad;
22use librashader_common::{FilterMode, Size, Viewport, WrapMode};
23use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::render_target::RenderTarget;
26use librashader_runtime::scaling::ScaleFramebuffer;
27
28use crate::error;
29use crate::error::FilterChainError;
30use crate::filter_pass::FilterPass;
31use crate::framebuffer::WgpuOutputView;
32use crate::graphics_pipeline::WgpuGraphicsPipeline;
33use crate::luts::LutTexture;
34use crate::mipmap::MipmapGen;
35use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
36use crate::samplers::SamplerSet;
37use crate::texture::{InputImage, OwnedImage};
38
39mod compile {
40 use super::*;
41 use librashader_pack::{PassResource, TextureResource};
42
43 #[cfg(not(feature = "stable"))]
44 pub type ShaderPassMeta =
45 ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;
46
47 #[cfg(feature = "stable")]
48 pub type ShaderPassMeta =
49 ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;
50
51 #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
52 pub fn compile_passes(
53 shaders: Vec<PassResource>,
54 textures: &[TextureResource],
55 ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
56 let (passes, semantics) = WGSL::compile_preset_passes::<
57 SpirvCompilation,
58 Naga,
59 FilterChainError,
60 >(shaders, textures.iter().map(|t| &t.meta))?;
61 Ok((passes, semantics))
62 }
63}
64
65use compile::{compile_passes, ShaderPassMeta};
66use librashader_pack::{ShaderPresetPack, TextureResource};
67use librashader_runtime::parameters::RuntimeParameters;
68
/// A wgpu-backed librashader filter chain.
///
/// Holds the compiled passes of a shader preset together with the framebuffers and state
/// needed to render frames through them.
pub struct FilterChainWgpu {
    /// State shared with every pass while drawing: input textures, LUTs, samplers,
    /// runtime parameters, and the device/queue.
    pub(crate) common: FilterCommon,
    /// The compiled passes, in preset order.
    passes: Box<[FilterPass]>,
    /// Per-pass render targets written during the current frame.
    output_framebuffers: Box<[OwnedImage]>,
    /// Per-pass render targets from the previous frame. `frame` swaps these with
    /// `output_framebuffers` before rendering.
    feedback_framebuffers: Box<[OwnedImage]>,
    /// Copies of previous input frames, newest at the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// When set, mipmaps are never generated for intermediate pass outputs.
    disable_mipmaps: bool,
    /// Mipmap generator shared by LUT loading and intermediate pass outputs.
    mipmapper: MipmapGen,
    /// Frame options used by `frame` when the caller passes `None`.
    default_frame_options: FrameOptionsWgpu,
    /// Whether the final pass is also rendered into its own framebuffer, so that it can be
    /// sampled as feedback on the next frame.
    draw_last_pass_feedback: bool,
}
81
/// State shared by all passes of a [`FilterChainWgpu`] while drawing a frame.
pub(crate) struct FilterCommon {
    /// Per-pass outputs of the current frame, wrapped as shader inputs (`None` until set).
    pub output_textures: Box<[Option<InputImage>]>,
    /// Per-pass outputs of the previous frame, wrapped as feedback shader inputs.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Previous input frames, wrapped as history shader inputs.
    pub history_textures: Box<[Option<InputImage>]>,
    /// Lookup textures, keyed by their index in the preset's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Samplers, looked up by wrap mode and filter modes.
    pub samplers: SamplerSet,
    /// Runtime shader parameters, including how many passes are enabled.
    pub config: RuntimeParameters,
    /// Quad geometry created once at load time and shared by all passes.
    pub(crate) draw_quad: DrawQuad,
    /// The device the chain was created on (a clone of the handle passed to the loader).
    pub(crate) device: wgpu::Device,
    /// The queue the chain was created with (a clone of the handle passed to the loader).
    pub(crate) queue: wgpu::Queue,
}
93
94impl FilterChainWgpu {
95 pub fn load_from_path(
97 path: impl AsRef<Path>,
98 features: ShaderFeatures,
99 device: &wgpu::Device,
100 queue: &wgpu::Queue,
101 options: Option<&FilterChainOptionsWgpu>,
102 ) -> error::Result<FilterChainWgpu> {
103 let preset = ShaderPreset::try_parse(path, features)?;
105
106 Self::load_from_preset(preset, device, queue, options)
107 }
108
109 pub fn load_from_preset(
111 preset: ShaderPreset,
112 device: &wgpu::Device,
113 queue: &wgpu::Queue,
114 options: Option<&FilterChainOptionsWgpu>,
115 ) -> error::Result<FilterChainWgpu> {
116 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
117 Self::load_from_pack(preset, device, queue, options)
118 }
119
120 pub fn load_from_pack(
122 preset: ShaderPresetPack,
123 device: &wgpu::Device,
124 queue: &wgpu::Queue,
125 options: Option<&FilterChainOptionsWgpu>,
126 ) -> error::Result<FilterChainWgpu> {
127 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
128 label: Some("librashader load cmd"),
129 });
130 let filter_chain =
131 Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
132
133 let cmd = cmd.finish();
134
135 let index = queue.submit([cmd]);
137 device.poll(wgpu::PollType::Wait {
138 submission_index: Some(index),
139 timeout: None,
140 })?;
141
142 Ok(filter_chain)
143 }
144
145 pub fn load_from_preset_deferred(
153 preset: ShaderPreset,
154 device: &wgpu::Device,
155 queue: &wgpu::Queue,
156 cmd: &mut wgpu::CommandEncoder,
157 options: Option<&FilterChainOptionsWgpu>,
158 ) -> error::Result<FilterChainWgpu> {
159 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
160 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
161 }
162
163 pub fn load_from_pack_deferred(
171 preset: ShaderPresetPack,
172 device: &wgpu::Device,
173 queue: &wgpu::Queue,
174 cmd: &mut wgpu::CommandEncoder,
175 options: Option<&FilterChainOptionsWgpu>,
176 ) -> error::Result<FilterChainWgpu> {
177 let config = RuntimeParameters::new(&preset);
178
179 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
180
181 let disable_cache = options.map_or(true, |o| !o.enable_cache);
183
184 let filters = Self::init_passes(
186 &device,
187 passes,
188 &semantics,
189 options.and_then(|o| o.adapter_info.as_ref()),
190 disable_cache,
191 )?;
192
193 let samplers = SamplerSet::new(&device);
194 let mut mipmapper = MipmapGen::new(&device);
195 let luts = FilterChainWgpu::load_luts(
196 &device,
197 &queue,
198 cmd,
199 &mut mipmapper,
200 &samplers,
201 preset.textures,
202 )?;
203 let framebuffer_gen = || {
205 Ok::<_, error::FilterChainError>(OwnedImage::new(
206 &device,
207 Size::new(1, 1),
208 1,
209 wgpu::TextureFormat::Bgra8Unorm,
210 ))
211 };
212 let input_gen = || None;
213 let framebuffer_init = FramebufferInit::new(
214 filters.iter().map(|f| &f.reflection.meta),
215 &framebuffer_gen,
216 &input_gen,
217 );
218
219 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
222 let (feedback_framebuffers, feedback_textures) =
225 framebuffer_init.init_output_framebuffers()?;
226 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
229
230 let draw_quad = DrawQuad::new(&device);
231
232 Ok(FilterChainWgpu {
233 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
234 common: FilterCommon {
235 luts,
236 samplers,
237 config,
238 draw_quad,
239 device: device.clone(),
240 queue: queue.clone(),
241 output_textures,
242 feedback_textures,
243 history_textures,
244 },
245 passes: filters,
246 output_framebuffers,
247 feedback_framebuffers,
248 history_framebuffers,
249 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
250 mipmapper,
251 default_frame_options: Default::default(),
252 })
253 }
254
255 fn load_luts(
256 device: &wgpu::Device,
257 queue: &wgpu::Queue,
258 cmd: &mut wgpu::CommandEncoder,
259 mipmapper: &mut MipmapGen,
260 sampler_set: &SamplerSet,
261 textures: Vec<TextureResource>,
262 ) -> error::Result<FastHashMap<usize, LutTexture>> {
263 let mut luts = FastHashMap::default();
264
265 #[cfg(not(target_arch = "wasm32"))]
266 let images_iter = textures.into_par_iter();
267
268 #[cfg(target_arch = "wasm32")]
269 let images_iter = textures.into_iter();
270
271 let textures = images_iter
272 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
273 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
274 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
275 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
276 luts.insert(index, texture);
277 }
278 Ok(luts)
279 }
280
281 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
282 if let Some(mut back) = self.history_framebuffers.pop_back() {
283 if back.image.size() != input.size() || input.format() != back.image.format() {
284 let _old_back = std::mem::replace(
286 &mut back,
287 OwnedImage::new(
288 &self.common.device,
289 input.size().into(),
290 1,
291 input.format().into(),
292 ),
293 );
294 }
295
296 back.copy_from(&self.common.device, cmd, input);
297
298 self.history_framebuffers.push_front(back)
299 }
300 }
301
302 fn init_passes(
303 device: &wgpu::Device,
304 passes: Vec<ShaderPassMeta>,
305 semantics: &ShaderSemantics,
306 adapter_info: Option<&wgpu::AdapterInfo>,
307 disable_cache: bool,
308 ) -> error::Result<Box<[FilterPass]>> {
309 #[cfg(not(target_arch = "wasm32"))]
310 let filter_creation_fn = || {
311 let passes_iter = passes.into_par_iter();
312 #[cfg(target_arch = "wasm32")]
313 let passes_iter = passes.into_iter();
314
315 let filters: Vec<error::Result<FilterPass>> = passes_iter
316 .enumerate()
317 .map(|(index, (config, mut reflect))| {
318 let reflection = reflect.reflect(index, semantics)?;
319 let wgsl = reflect.compile(NagaLoweringOptions {
320 write_pcb_as_ubo: true,
321 sampler_bind_group: 1,
322 })?;
323
324 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
325 let push_size = reflection
326 .push_constant
327 .as_ref()
328 .map_or(0, |push| push.size as wgpu::BufferAddress);
329
330 let uniform_storage = UniformStorage::new_with_storage(
331 WgpuStagedBuffer::new(
332 &device,
333 wgpu::BufferUsages::UNIFORM,
334 ubo_size as wgpu::BufferAddress,
335 Some("ubo"),
336 ),
337 WgpuStagedBuffer::new(
338 &device,
339 wgpu::BufferUsages::UNIFORM,
340 push_size as wgpu::BufferAddress,
341 Some("push"),
342 ),
343 );
344
345 let uniform_bindings =
346 reflection.meta.create_binding_map(|param| param.offset());
347
348 let render_pass_format: Option<wgpu::TextureFormat> =
349 if let Some(format) = config.meta.get_format_override() {
350 format.into()
351 } else {
352 config.data.format.into()
353 };
354
355 let graphics_pipeline = WgpuGraphicsPipeline::new(
356 &device,
357 &wgsl,
358 &reflection,
359 render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
360 adapter_info,
361 disable_cache,
362 );
363
364 Ok(FilterPass {
365 reflection,
366 uniform_storage,
367 uniform_bindings,
368 source: config.data,
369 meta: config.meta,
370 graphics_pipeline,
371 })
372 })
373 .collect();
374 filters
375 };
376
377 #[cfg(target_arch = "wasm32")]
378 let filters = filter_creation_fn();
379
380 #[cfg(not(target_arch = "wasm32"))]
381 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
382 .stack_size(10 * 1048576)
384 .build()
385 {
386 thread_pool.install(|| filter_creation_fn())
387 } else {
388 filter_creation_fn()
389 };
390
391 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
392 let filters = filters?;
393 Ok(filters.into_boxed_slice())
394 }
395
396 pub fn frame<'a>(
398 &mut self,
399 input: &wgpu::Texture,
400 viewport: &Viewport<WgpuOutputView<'a>>,
401 cmd: &mut wgpu::CommandEncoder,
402 frame_count: usize,
403 options: Option<&FrameOptionsWgpu>,
404 ) -> error::Result<()> {
405 let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
406 let passes = &mut self.passes[0..max];
407
408 if let Some(options) = &options {
409 if options.clear_history {
410 for history in &mut self.history_framebuffers {
411 history.clear(cmd);
412 }
413 }
414 }
415
416 if passes.is_empty() {
417 return Ok(());
418 }
419
420 let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
421
422 let filter = passes[0].meta.filter;
423 let wrap_mode = passes[0].meta.wrap_mode;
424
425 for (texture, image) in self
427 .common
428 .history_textures
429 .iter_mut()
430 .zip(self.history_framebuffers.iter())
431 {
432 *texture = Some(image.as_input(filter, wrap_mode));
433 }
434
435 let original = InputImage {
436 image: input.clone(),
437 view: original_image_view,
438 wrap_mode,
439 filter_mode: filter,
440 mip_filter: filter,
441 };
442
443 let mut source = original.clone();
444
445 std::mem::swap(
447 &mut self.output_framebuffers,
448 &mut self.feedback_framebuffers,
449 );
450
451 OwnedImage::scale_framebuffers_with_context(
453 source.image.size().into(),
454 viewport.output.size,
455 original.image.size().into(),
456 &mut self.output_framebuffers,
457 &mut self.feedback_framebuffers,
458 passes,
459 &self.common.device,
460 Some(&mut |index: usize,
461 pass: &FilterPass,
462 output: &OwnedImage,
463 feedback: &OwnedImage| {
464 self.common.feedback_textures[index] =
466 Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
467 self.common.output_textures[index] =
468 Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
469 Ok(())
470 }),
471 )?;
472
473 let passes_len = passes.len();
474 let (pass, last) = passes.split_at_mut(passes_len - 1);
475
476 let options = options.unwrap_or(&self.default_frame_options);
477
478 for (index, pass) in pass.iter_mut().enumerate() {
479 source.filter_mode = pass.meta.filter;
480 source.wrap_mode = pass.meta.wrap_mode;
481 source.mip_filter = pass.meta.filter;
482
483 let target = &self.output_framebuffers[index];
484 let output_image = WgpuOutputView::from(target);
485 let out = RenderTarget::identity(&output_image)?;
486
487 pass.draw(
488 cmd,
489 index,
490 &self.common,
491 pass.meta.get_frame_count(frame_count),
492 options,
493 viewport,
494 &original,
495 &source,
496 &out,
497 QuadType::Offscreen,
498 )?;
499
500 if target.max_miplevels > 1 && !self.disable_mipmaps {
501 let sampler = self.common.samplers.get(
502 WrapMode::ClampToEdge,
503 FilterMode::Linear,
504 FilterMode::Nearest,
505 );
506
507 target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
508 }
509
510 source = self.common.output_textures[index].clone().unwrap();
511 }
512
513 assert_eq!(last.len(), 1);
515
516 if let Some(pass) = last.iter_mut().next() {
517 let index = passes_len - 1;
518 if !pass.graphics_pipeline.has_format(viewport.output.format) {
519 pass.graphics_pipeline
521 .recompile(&self.common.device, viewport.output.format);
522 }
523
524 source.filter_mode = pass.meta.filter;
525 source.wrap_mode = pass.meta.wrap_mode;
526 source.mip_filter = pass.meta.filter;
527
528 if self.draw_last_pass_feedback {
529 let target = &self.output_framebuffers[index];
530 let output_image = WgpuOutputView::from(target);
531 let out = RenderTarget::viewport_with_output(&output_image, viewport);
532
533 pass.draw(
534 cmd,
535 index,
536 &self.common,
537 pass.meta.get_frame_count(frame_count),
538 options,
539 viewport,
540 &original,
541 &source,
542 &out,
543 QuadType::Final,
544 )?;
545 }
546
547 let out = RenderTarget::viewport(viewport);
548 pass.draw(
549 cmd,
550 index,
551 &self.common,
552 pass.meta.get_frame_count(frame_count),
553 options,
554 viewport,
555 &original,
556 &source,
557 &out,
558 QuadType::Final,
559 )?;
560 }
561
562 self.push_history(&input, cmd);
563 Ok(())
564 }
565}