1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::ThreadPoolBuilder;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::FramebufferInit;
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28
29use crate::error;
30use crate::error::FilterChainError;
31use crate::filter_pass::FilterPass;
32use crate::framebuffer::WgpuOutputView;
33use crate::graphics_pipeline::WgpuGraphicsPipeline;
34use crate::luts::LutTexture;
35use crate::mipmap::MipmapGen;
36use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
37use crate::samplers::SamplerSet;
38use crate::texture::{InputImage, OwnedImage};
39
#[cfg(feature = "native")]
mod compile {
    //! Pass compilation for presets whose sources go through the
    //! `SpirvCompilation` front end and are lowered to WGSL with Naga.
    use super::*;
    use librashader_pack::{PassResource, TextureResource};

    /// Compiled pass artifact; an opaque type on nightly.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;

    /// Compiled pass artifact; a boxed trait object on stable.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;

    /// Prepare every pass of a preset pack for WGSL compilation.
    ///
    /// Uses the `SpirvCompilation` front end with the `Naga` reflection
    /// backend. Returns the per-pass artifacts together with the
    /// `ShaderSemantics` derived from the passes and the texture metadata.
    ///
    /// # Errors
    /// Propagates any `FilterChainError` reported by
    /// `WGSL::compile_preset_passes`.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_meta = textures.iter().map(|texture| &texture.meta);
        let (artifacts, pass_semantics) =
            WGSL::compile_preset_passes::<SpirvCompilation, Naga, FilterChainError>(
                shaders,
                texture_meta,
            )?;
        Ok((artifacts, pass_semantics))
    }
}
66
#[cfg(feature = "wgsl_preset_pack")]
mod compile_wgsl {
    //! Pass compilation for presets whose sources go through the
    //! `WgslCompilation` front end and are lowered to WGSL with Naga.
    use super::*;
    use librashader_pack::{PassResource, TextureResource};
    use librashader_reflect::front::WgslCompilation;

    /// Compiled pass artifact; an opaque type on nightly.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>;

    /// Compiled pass artifact; a boxed trait object on stable.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>>;

    /// Prepare every pass of a preset pack for WGSL compilation.
    ///
    /// Uses the `WgslCompilation` front end with the `Naga` reflection
    /// backend. Returns the per-pass artifacts together with the
    /// `ShaderSemantics` derived from the passes and the texture metadata.
    ///
    /// # Errors
    /// Propagates any `FilterChainError` reported by
    /// `WGSL::compile_preset_passes`.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_meta = textures.iter().map(|texture| &texture.meta);
        let (artifacts, pass_semantics) =
            WGSL::compile_preset_passes::<WgslCompilation, Naga, FilterChainError>(
                shaders,
                texture_meta,
            )?;
        Ok((artifacts, pass_semantics))
    }
}
94
95#[cfg(all(feature = "native", not(feature = "wgsl_preset_pack")))]
96use compile::{compile_passes, ShaderPassMeta};
97
98#[cfg(any(not(feature = "native"), feature = "wgsl_preset_pack"))]
99use compile_wgsl::{compile_passes, ShaderPassMeta};
100
101use librashader_pack::{ShaderPresetPack, TextureResource};
102use librashader_runtime::parameters::RuntimeParameters;
103
/// A wgpu filter chain that runs a librashader shader preset.
///
/// Built by one of the `load_from_*` constructors and driven once per frame
/// with [`FilterChainWgpu::frame`].
pub struct FilterChainWgpu {
    /// State handed to every pass while drawing: input/output/feedback/history
    /// textures, LUTs, samplers, runtime parameters, and device/queue handles.
    pub(crate) common: FilterCommon,
    /// Compiled filter passes, in preset order.
    passes: Box<[FilterPass]>,
    /// Per-pass render targets for the frame currently being drawn.
    output_framebuffers: Box<[OwnedImage]>,
    /// Per-pass render targets from the previous frame; swapped with
    /// `output_framebuffers` at the start of each `frame` call.
    feedback_framebuffers: Box<[OwnedImage]>,
    /// Copies of previous input frames; `push_history` recycles the back
    /// entry and pushes the newest copy to the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// When true, mipmaps are not generated for intermediate pass outputs
    /// (set from `FilterChainOptionsWgpu::force_no_mipmaps`).
    disable_mipmaps: bool,
    /// Mipmap generator used for LUT textures and pass outputs.
    mipmapper: MipmapGen,
    /// Frame options used when `frame` is called with `None`.
    default_frame_options: FrameOptionsWgpu,
    /// Whether the final pass is additionally rendered into its own
    /// framebuffer so it can be read back as feedback on later frames.
    draw_last_pass_feedback: bool,
}
116
/// Resources shared by all passes of a [`FilterChainWgpu`] while drawing.
pub(crate) struct FilterCommon {
    /// Per-pass outputs exposed as shader inputs; refreshed by `frame`
    /// whenever the framebuffers are (re)scaled.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Per-pass previous-frame outputs exposed as shader inputs; refreshed
    /// alongside `output_textures`.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Views of the history framebuffers (earlier input frames), set at the
    /// start of each `frame` call.
    pub history_textures: Box<[Option<InputImage>]>,
    /// LUT textures keyed by their index in the preset's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Sampler objects looked up by wrap mode and filter modes.
    pub samplers: SamplerSet,
    /// Runtime parameters of the preset (e.g. how many passes are enabled).
    pub config: RuntimeParameters,
    /// Geometry used to draw each pass.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device handle the chain was loaded with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue handle the chain was loaded with.
    pub(crate) queue: wgpu::Queue,
}
128
129impl FilterChainWgpu {
130 #[cfg(feature = "native")]
132 pub fn load_from_path(
133 path: impl AsRef<Path>,
134 features: ShaderFeatures,
135 device: &wgpu::Device,
136 queue: &wgpu::Queue,
137 options: Option<&FilterChainOptionsWgpu>,
138 ) -> error::Result<FilterChainWgpu> {
139 let preset = ShaderPreset::try_parse(path, features)?;
141
142 Self::load_from_preset(preset, device, queue, options)
143 }
144
145 #[cfg(feature = "native")]
147 pub fn load_from_preset(
148 preset: ShaderPreset,
149 device: &wgpu::Device,
150 queue: &wgpu::Queue,
151 options: Option<&FilterChainOptionsWgpu>,
152 ) -> error::Result<FilterChainWgpu> {
153 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
154 Self::load_from_pack(preset, device, queue, options)
155 }
156
157 pub fn load_from_pack(
159 preset: ShaderPresetPack,
160 device: &wgpu::Device,
161 queue: &wgpu::Queue,
162 options: Option<&FilterChainOptionsWgpu>,
163 ) -> error::Result<FilterChainWgpu> {
164 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
165 label: Some("librashader load cmd"),
166 });
167 let filter_chain =
168 Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
169
170 let cmd = cmd.finish();
171
172 let index = queue.submit([cmd]);
174 device.poll(wgpu::PollType::Wait {
175 submission_index: Some(index),
176 timeout: None,
177 })?;
178
179 Ok(filter_chain)
180 }
181
182 #[cfg(feature = "native")]
190 pub fn load_from_preset_deferred(
191 preset: ShaderPreset,
192 device: &wgpu::Device,
193 queue: &wgpu::Queue,
194 cmd: &mut wgpu::CommandEncoder,
195 options: Option<&FilterChainOptionsWgpu>,
196 ) -> error::Result<FilterChainWgpu> {
197 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
198 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
199 }
200
201 pub fn load_from_pack_deferred(
209 preset: ShaderPresetPack,
210 device: &wgpu::Device,
211 queue: &wgpu::Queue,
212 cmd: &mut wgpu::CommandEncoder,
213 options: Option<&FilterChainOptionsWgpu>,
214 ) -> error::Result<FilterChainWgpu> {
215 let config = RuntimeParameters::new(&preset);
216
217 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
218
219 let disable_cache = options.map_or(true, |o| !o.enable_cache);
221
222 let filters = Self::init_passes(
224 &device,
225 passes,
226 &semantics,
227 options.and_then(|o| o.adapter_info.as_ref()),
228 disable_cache,
229 )?;
230
231 let samplers = SamplerSet::new(&device);
232 let mut mipmapper = MipmapGen::new(&device);
233 let luts = FilterChainWgpu::load_luts(
234 &device,
235 &queue,
236 cmd,
237 &mut mipmapper,
238 &samplers,
239 preset.textures,
240 )?;
241 let framebuffer_gen = || {
243 Ok::<_, error::FilterChainError>(OwnedImage::new(
244 &device,
245 Size::new(1, 1),
246 1,
247 wgpu::TextureFormat::Bgra8Unorm,
248 ))
249 };
250 let input_gen = || None;
251 let framebuffer_init = FramebufferInit::new(
252 filters.iter().map(|f| &f.reflection.meta),
253 &framebuffer_gen,
254 &input_gen,
255 );
256
257 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
260 let (feedback_framebuffers, feedback_textures) =
263 framebuffer_init.init_output_framebuffers()?;
264 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
267
268 let draw_quad = DrawQuad::new(&device);
269
270 Ok(FilterChainWgpu {
271 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
272 common: FilterCommon {
273 luts,
274 samplers,
275 config,
276 draw_quad,
277 device: device.clone(),
278 queue: queue.clone(),
279 output_textures,
280 feedback_textures,
281 history_textures,
282 },
283 passes: filters,
284 output_framebuffers,
285 feedback_framebuffers,
286 history_framebuffers,
287 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
288 mipmapper,
289 default_frame_options: Default::default(),
290 })
291 }
292
293 fn load_luts(
294 device: &wgpu::Device,
295 queue: &wgpu::Queue,
296 cmd: &mut wgpu::CommandEncoder,
297 mipmapper: &mut MipmapGen,
298 sampler_set: &SamplerSet,
299 textures: Vec<TextureResource>,
300 ) -> error::Result<FastHashMap<usize, LutTexture>> {
301 let mut luts = FastHashMap::default();
302
303 #[cfg(not(target_arch = "wasm32"))]
304 let images_iter = textures.into_par_iter();
305
306 #[cfg(target_arch = "wasm32")]
307 let images_iter = textures.into_iter();
308
309 let textures = images_iter
310 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
311 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
312 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
313 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
314 luts.insert(index, texture);
315 }
316 Ok(luts)
317 }
318
319 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
320 if let Some(mut back) = self.history_framebuffers.pop_back() {
321 if back.image.size() != input.size() || input.format() != back.image.format() {
322 let _old_back = std::mem::replace(
324 &mut back,
325 OwnedImage::new(
326 &self.common.device,
327 input.size().into(),
328 1,
329 input.format().into(),
330 ),
331 );
332 }
333
334 back.copy_from(&self.common.device, cmd, input);
335
336 self.history_framebuffers.push_front(back)
337 }
338 }
339
340 fn init_passes(
341 device: &wgpu::Device,
342 passes: Vec<ShaderPassMeta>,
343 semantics: &ShaderSemantics,
344 adapter_info: Option<&wgpu::AdapterInfo>,
345 disable_cache: bool,
346 ) -> error::Result<Box<[FilterPass]>> {
347 let filter_creation_fn = || {
348 #[cfg(not(target_arch = "wasm32"))]
349 let passes_iter = passes.into_par_iter();
350 #[cfg(target_arch = "wasm32")]
351 let passes_iter = passes.into_iter();
352
353 let filters: Vec<error::Result<FilterPass>> = passes_iter
354 .enumerate()
355 .map(|(index, (config, mut reflect))| {
356 let reflection = reflect.reflect(index, semantics)?;
357 let wgsl = reflect.compile(NagaLoweringOptions {
358 write_pcb_as_ubo: true,
359 sampler_bind_group: 1,
360 suppress_derivative_uniformity: true,
361 })?;
362
363 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
364 let push_size = reflection
365 .push_constant
366 .as_ref()
367 .map_or(0, |push| push.size as wgpu::BufferAddress);
368
369 let uniform_storage = UniformStorage::new_with_storage(
370 WgpuStagedBuffer::new(
371 &device,
372 wgpu::BufferUsages::UNIFORM,
373 ubo_size as wgpu::BufferAddress,
374 Some("ubo"),
375 ),
376 WgpuStagedBuffer::new(
377 &device,
378 wgpu::BufferUsages::UNIFORM,
379 push_size as wgpu::BufferAddress,
380 Some("push"),
381 ),
382 );
383
384 let uniform_bindings =
385 reflection.meta.create_binding_map(|param| param.offset());
386
387 let render_pass_format: Option<wgpu::TextureFormat> =
388 if let Some(format) = config.meta.get_format_override() {
389 format.into()
390 } else {
391 config.data.format.into()
392 };
393
394 let graphics_pipeline = WgpuGraphicsPipeline::new(
395 &device,
396 &wgsl,
397 &reflection,
398 render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
399 adapter_info,
400 disable_cache,
401 );
402
403 Ok(FilterPass {
404 reflection,
405 uniform_storage,
406 uniform_bindings,
407 source: config.data,
408 meta: config.meta,
409 graphics_pipeline,
410 })
411 })
412 .collect();
413 filters
414 };
415
416 #[cfg(target_arch = "wasm32")]
417 let filters = filter_creation_fn();
418
419 #[cfg(not(target_arch = "wasm32"))]
420 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
421 .stack_size(10 * 1048576)
423 .build()
424 {
425 thread_pool.install(|| filter_creation_fn())
426 } else {
427 filter_creation_fn()
428 };
429
430 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
431 let filters = filters?;
432 Ok(filters.into_boxed_slice())
433 }
434
435 pub fn frame<'a>(
437 &mut self,
438 input: &wgpu::Texture,
439 viewport: &Viewport<WgpuOutputView<'a>>,
440 cmd: &mut wgpu::CommandEncoder,
441 frame_count: usize,
442 options: Option<&FrameOptionsWgpu>,
443 ) -> error::Result<()> {
444 let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
445 let passes = &mut self.passes[0..max];
446
447 if let Some(options) = &options {
448 if options.clear_history {
449 for history in &mut self.history_framebuffers {
450 history.clear(cmd);
451 }
452 }
453 }
454
455 if passes.is_empty() {
456 return Ok(());
457 }
458
459 let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());
460
461 let filter = passes[0].meta.filter;
462 let wrap_mode = passes[0].meta.wrap_mode;
463
464 for (texture, image) in self
466 .common
467 .history_textures
468 .iter_mut()
469 .zip(self.history_framebuffers.iter())
470 {
471 *texture = Some(image.as_input(filter, wrap_mode));
472 }
473
474 let original = InputImage {
475 image: input.clone(),
476 view: original_image_view,
477 wrap_mode,
478 filter_mode: filter,
479 mip_filter: filter,
480 };
481
482 let mut source = original.clone();
483
484 std::mem::swap(
486 &mut self.output_framebuffers,
487 &mut self.feedback_framebuffers,
488 );
489
490 OwnedImage::scale_framebuffers_with_context(
492 source.image.size().into(),
493 viewport.output.size,
494 original.image.size().into(),
495 &mut self.output_framebuffers,
496 &mut self.feedback_framebuffers,
497 passes,
498 &self.common.device,
499 Some(&mut |index: usize,
500 pass: &FilterPass,
501 output: &OwnedImage,
502 feedback: &OwnedImage| {
503 self.common.feedback_textures[index] =
505 Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
506 self.common.output_textures[index] =
507 Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode));
508 Ok(())
509 }),
510 )?;
511
512 let passes_len = passes.len();
513 let (pass, last) = passes.split_at_mut(passes_len - 1);
514
515 let options = options.unwrap_or(&self.default_frame_options);
516
517 for (index, pass) in pass.iter_mut().enumerate() {
518 source.filter_mode = pass.meta.filter;
519 source.wrap_mode = pass.meta.wrap_mode;
520 source.mip_filter = pass.meta.filter;
521
522 let target = &self.output_framebuffers[index];
523 let output_image = WgpuOutputView::from(target);
524 let out = RenderTarget::identity(&output_image)?;
525
526 pass.draw(
527 cmd,
528 index,
529 &self.common,
530 pass.meta.get_frame_count(frame_count),
531 options,
532 viewport,
533 &original,
534 &source,
535 &out,
536 None,
537 QuadType::Offscreen,
538 )?;
539
540 if target.max_miplevels > 1 && !self.disable_mipmaps {
541 let sampler = self.common.samplers.get(
542 WrapMode::ClampToEdge,
543 FilterMode::Linear,
544 FilterMode::Nearest,
545 );
546
547 target.generate_mipmaps(&self.common.device, cmd, &mut self.mipmapper, &sampler);
548 }
549
550 source = self.common.output_textures[index].clone().unwrap();
551 }
552
553 assert_eq!(last.len(), 1);
555
556 if let Some(pass) = last.iter_mut().next() {
557 let index = passes_len - 1;
558 if !pass.graphics_pipeline.has_format(viewport.output.format) {
559 pass.graphics_pipeline
561 .recompile(&self.common.device, viewport.output.format);
562 }
563
564 source.filter_mode = pass.meta.filter;
565 source.wrap_mode = pass.meta.wrap_mode;
566 source.mip_filter = pass.meta.filter;
567
568 let output_size_override = if self.draw_last_pass_feedback {
575 let target = &self.output_framebuffers[index];
576 let output_image = WgpuOutputView::from(target);
577 let out = RenderTarget::viewport_with_output(&output_image, viewport);
578
579 pass.draw(
580 cmd,
581 index,
582 &self.common,
583 pass.meta.get_frame_count(frame_count),
584 options,
585 viewport,
586 &original,
587 &source,
588 &out,
589 None,
590 QuadType::Final,
591 )?;
592 Some(target.size)
593 } else {
594 None
595 };
596
597 let out = RenderTarget::viewport(viewport);
598 pass.draw(
599 cmd,
600 index,
601 &self.common,
602 pass.meta.get_frame_count(frame_count),
603 options,
604 viewport,
605 &original,
606 &source,
607 &out,
608 output_size_override,
609 QuadType::Final,
610 )?;
611 }
612
613 self.push_history(&input, cmd);
614 Ok(())
615 }
616}