1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::ThreadPoolBuilder;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::{FramebufferInit, FramebufferPool};
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28
29use crate::error;
30use crate::error::FilterChainError;
31use crate::filter_pass::FilterPass;
32use crate::framebuffer::WgpuOutputView;
33use crate::graphics_pipeline::WgpuGraphicsPipeline;
34use crate::luts::LutTexture;
35use crate::mipmap::MipmapGen;
36use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
37use crate::samplers::SamplerSet;
38use crate::texture::{InputImage, OwnedImage};
39
/// Pass compilation from SPIR-V (`SpirvCompilation`) sources to WGSL through Naga.
///
/// Selected when the `native` feature is enabled and `wgsl_preset_pack` is not.
#[cfg(feature = "native")]
mod compile {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};

    /// A compiled-but-not-yet-reflected pass: its pass resource paired with a compiler
    /// object that can reflect and emit WGSL. On nightly the compiler type is opaque.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;

    /// A compiled-but-not-yet-reflected pass: its pass resource paired with a boxed
    /// compiler object that can reflect and emit WGSL.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;

    /// Compiles every pass in `shaders` for the WGSL target.
    ///
    /// Only the metadata of each entry in `textures` is consulted. Returns the per-pass
    /// artifacts together with the shader semantics produced for the whole preset.
    ///
    /// # Errors
    /// Any failure from `compile_preset_passes` is returned as a [`FilterChainError`].
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_meta = textures.iter().map(|texture| &texture.meta);
        let (artifacts, preset_semantics) =
            WGSL::compile_preset_passes::<SpirvCompilation, Naga, FilterChainError>(
                shaders,
                texture_meta,
            )?;
        Ok((artifacts, preset_semantics))
    }
}
66
/// Pass compilation for preset packs whose shaders are already WGSL (`WgslCompilation`),
/// re-emitted as WGSL through Naga.
///
/// Selected when `wgsl_preset_pack` is enabled or `native` is disabled.
#[cfg(feature = "wgsl_preset_pack")]
mod compile_wgsl {
    use super::*;
    use librashader_pack::{PassResource, TextureResource};
    use librashader_reflect::front::WgslCompilation;

    /// A compiled-but-not-yet-reflected pass: its pass resource paired with a compiler
    /// object that can reflect and emit WGSL. On nightly the compiler type is opaque.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>;

    /// A compiled-but-not-yet-reflected pass: its pass resource paired with a boxed
    /// compiler object that can reflect and emit WGSL.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>>;

    /// Compiles every WGSL pass in `shaders` for the WGSL target.
    ///
    /// Only the metadata of each entry in `textures` is consulted. Returns the per-pass
    /// artifacts together with the shader semantics produced for the whole preset.
    ///
    /// # Errors
    /// Any failure from `compile_preset_passes` is returned as a [`FilterChainError`].
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_meta = textures.iter().map(|texture| &texture.meta);
        let (artifacts, preset_semantics) =
            WGSL::compile_preset_passes::<WgslCompilation, Naga, FilterChainError>(
                shaders,
                texture_meta,
            )?;
        Ok((artifacts, preset_semantics))
    }
}
94
95#[cfg(all(feature = "native", not(feature = "wgsl_preset_pack")))]
96use compile::{compile_passes, ShaderPassMeta};
97
98#[cfg(any(not(feature = "native"), feature = "wgsl_preset_pack"))]
99use compile_wgsl::{compile_passes, ShaderPassMeta};
100
101use librashader_pack::{ShaderPresetPack, TextureResource};
102use librashader_runtime::parameters::RuntimeParameters;
103
/// A wgpu filter chain: the compiled passes of a shader preset together with the
/// framebuffers and shared state needed to run them each frame.
///
/// NOTE(review): struct fields are dropped in declaration order; keep this order unless
/// the teardown order of the contained GPU resources has been verified.
pub struct FilterChainWgpu {
    /// State shared with every pass while drawing (textures, LUTs, samplers, device, queue).
    pub(crate) common: FilterCommon,
    /// The compiled passes, in preset order.
    passes: Box<[FilterPass]>,
    /// Per-pass output framebuffers.
    output_framebuffers: FramebufferPool<OwnedImage>,
    /// Per-pass feedback framebuffers; in `frame` each one present is swapped with the
    /// matching output framebuffer before drawing.
    feedback_framebuffers: FramebufferPool<OwnedImage>,
    /// Copies of previous input frames; `push_history` pops from the back and pushes to
    /// the front, so the front is the most recent frame.
    history_framebuffers: VecDeque<OwnedImage>,
    /// When `true`, mipmaps are never generated for intermediate pass outputs.
    disable_mipmaps: bool,
    /// Mipmap generator used for LUTs at load time and for pass outputs in `frame`.
    mipmapper: MipmapGen,
    /// Frame options used when `frame` is called with `options == None`.
    default_frame_options: FrameOptionsWgpu,
    /// Whether the final pass's output is used as feedback; if so, `frame` also renders
    /// the final pass into its own framebuffer.
    draw_last_pass_feedback: bool,
}
116
/// State shared by all passes of a [`FilterChainWgpu`] and handed to each pass when it draws.
///
/// NOTE(review): struct fields are dropped in declaration order; keep this order unless
/// the teardown order of the contained GPU resources has been verified.
pub(crate) struct FilterCommon {
    /// Input views of each pass's output, indexed by pass; `None` until a pass has produced output.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Input views of each pass's feedback framebuffer, indexed by pass; `None` when unset.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Input views of the stored history frames (front = most recent); `None` when unset.
    pub history_textures: Box<[Option<InputImage>]>,
    /// Loaded LUT textures, keyed by their position in the preset's texture list.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Samplers shared by every pass.
    pub samplers: SamplerSet,
    /// Runtime parameters created from the preset; `frame` reads `passes_enabled()` from it.
    pub config: RuntimeParameters,
    /// Quad geometry used when drawing passes.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device handle the chain was created with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue handle the chain was created with.
    pub(crate) queue: wgpu::Queue,
}
128
129impl FilterChainWgpu {
130 #[cfg(feature = "native")]
132 pub fn load_from_path(
133 path: impl AsRef<Path>,
134 features: ShaderFeatures,
135 device: &wgpu::Device,
136 queue: &wgpu::Queue,
137 options: Option<&FilterChainOptionsWgpu>,
138 ) -> error::Result<FilterChainWgpu> {
139 let preset = ShaderPreset::try_parse(path, features)?;
141
142 Self::load_from_preset(preset, device, queue, options)
143 }
144
145 #[cfg(feature = "native")]
147 pub fn load_from_preset(
148 preset: ShaderPreset,
149 device: &wgpu::Device,
150 queue: &wgpu::Queue,
151 options: Option<&FilterChainOptionsWgpu>,
152 ) -> error::Result<FilterChainWgpu> {
153 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
154 Self::load_from_pack(preset, device, queue, options)
155 }
156
157 pub fn load_from_pack(
159 preset: ShaderPresetPack,
160 device: &wgpu::Device,
161 queue: &wgpu::Queue,
162 options: Option<&FilterChainOptionsWgpu>,
163 ) -> error::Result<FilterChainWgpu> {
164 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
165 label: Some("librashader load cmd"),
166 });
167 let filter_chain =
168 Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
169
170 let cmd = cmd.finish();
171
172 let index = queue.submit([cmd]);
174 device.poll(wgpu::PollType::Wait {
175 submission_index: Some(index),
176 timeout: None,
177 })?;
178
179 Ok(filter_chain)
180 }
181
182 #[cfg(feature = "native")]
190 pub fn load_from_preset_deferred(
191 preset: ShaderPreset,
192 device: &wgpu::Device,
193 queue: &wgpu::Queue,
194 cmd: &mut wgpu::CommandEncoder,
195 options: Option<&FilterChainOptionsWgpu>,
196 ) -> error::Result<FilterChainWgpu> {
197 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
198 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
199 }
200
201 pub fn load_from_pack_deferred(
209 preset: ShaderPresetPack,
210 device: &wgpu::Device,
211 queue: &wgpu::Queue,
212 cmd: &mut wgpu::CommandEncoder,
213 options: Option<&FilterChainOptionsWgpu>,
214 ) -> error::Result<FilterChainWgpu> {
215 let config = RuntimeParameters::new(&preset);
216
217 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
218
219 let disable_cache = options.map_or(true, |o| !o.enable_cache);
221
222 let filters = Self::init_passes(
224 &device,
225 passes,
226 &semantics,
227 options.and_then(|o| o.adapter_info.as_ref()),
228 disable_cache,
229 )?;
230
231 let samplers = SamplerSet::new(&device);
232 let mut mipmapper = MipmapGen::new(&device);
233 let luts = FilterChainWgpu::load_luts(
234 &device,
235 &queue,
236 cmd,
237 &mut mipmapper,
238 &samplers,
239 preset.textures,
240 )?;
241 let framebuffer_gen = || {
243 Ok::<_, error::FilterChainError>(OwnedImage::new(
244 &device,
245 Size::new(1, 1),
246 1,
247 wgpu::TextureFormat::Bgra8Unorm,
248 ))
249 };
250 let input_gen = || None;
251 let framebuffer_init = FramebufferInit::new(
252 filters.iter().map(|f| &f.reflection.meta),
253 &framebuffer_gen,
254 &input_gen,
255 );
256
257 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
260 let (feedback_framebuffers, feedback_textures) =
263 framebuffer_init.init_feedback_framebuffers()?;
264 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
267
268 let draw_quad = DrawQuad::new(&device);
269
270 Ok(FilterChainWgpu {
271 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
272 common: FilterCommon {
273 luts,
274 samplers,
275 config,
276 draw_quad,
277 device: device.clone(),
278 queue: queue.clone(),
279 output_textures,
280 feedback_textures,
281 history_textures,
282 },
283 passes: filters,
284 output_framebuffers,
285 feedback_framebuffers,
286 history_framebuffers,
287 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
288 mipmapper,
289 default_frame_options: Default::default(),
290 })
291 }
292
293 fn load_luts(
294 device: &wgpu::Device,
295 queue: &wgpu::Queue,
296 cmd: &mut wgpu::CommandEncoder,
297 mipmapper: &mut MipmapGen,
298 sampler_set: &SamplerSet,
299 textures: Vec<TextureResource>,
300 ) -> error::Result<FastHashMap<usize, LutTexture>> {
301 let mut luts = FastHashMap::default();
302
303 #[cfg(not(target_arch = "wasm32"))]
304 let images_iter = textures.into_par_iter();
305
306 #[cfg(target_arch = "wasm32")]
307 let images_iter = textures.into_iter();
308
309 let textures = images_iter
310 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
311 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
312 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
313 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
314 luts.insert(index, texture);
315 }
316 Ok(luts)
317 }
318
319 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
320 if let Some(mut back) = self.history_framebuffers.pop_back() {
321 if back.image.size() != input.size() || input.format() != back.image.format() {
322 let _old_back = std::mem::replace(
324 &mut back,
325 OwnedImage::new(
326 &self.common.device,
327 input.size().into(),
328 1,
329 input.format().into(),
330 ),
331 );
332 }
333
334 back.copy_from(&self.common.device, cmd, input);
335
336 self.history_framebuffers.push_front(back)
337 }
338 }
339
340 fn init_passes(
341 device: &wgpu::Device,
342 passes: Vec<ShaderPassMeta>,
343 semantics: &ShaderSemantics,
344 adapter_info: Option<&wgpu::AdapterInfo>,
345 disable_cache: bool,
346 ) -> error::Result<Box<[FilterPass]>> {
347 let filter_creation_fn = || {
348 #[cfg(not(target_arch = "wasm32"))]
349 let passes_iter = passes.into_par_iter();
350 #[cfg(target_arch = "wasm32")]
351 let passes_iter = passes.into_iter();
352
353 let filters: Vec<error::Result<FilterPass>> = passes_iter
354 .enumerate()
355 .map(|(index, (config, mut reflect))| {
356 let reflection = reflect.reflect(index, semantics)?;
357 let wgsl = reflect.compile(NagaLoweringOptions {
358 write_pcb_as_ubo: true,
359 sampler_bind_group: 1,
360 suppress_derivative_uniformity: true,
361 })?;
362
363 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
364 let push_size = reflection
365 .push_constant
366 .as_ref()
367 .map_or(0, |push| push.size as wgpu::BufferAddress);
368
369 let uniform_storage = UniformStorage::new_with_storage(
370 WgpuStagedBuffer::new(
371 &device,
372 wgpu::BufferUsages::UNIFORM,
373 ubo_size as wgpu::BufferAddress,
374 Some("ubo"),
375 ),
376 WgpuStagedBuffer::new(
377 &device,
378 wgpu::BufferUsages::UNIFORM,
379 push_size as wgpu::BufferAddress,
380 Some("push"),
381 ),
382 );
383
384 let uniform_bindings =
385 reflection.meta.create_binding_map(|param| param.offset());
386
387 let render_pass_format: Option<wgpu::TextureFormat> =
388 if let Some(format) = config.meta.get_format_override() {
389 format.into()
390 } else {
391 config.data.format.into()
392 };
393
394 let graphics_pipeline = WgpuGraphicsPipeline::new(
395 &device,
396 &wgsl,
397 &reflection,
398 render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
399 adapter_info,
400 disable_cache,
401 );
402
403 Ok(FilterPass {
404 reflection,
405 uniform_storage,
406 uniform_bindings,
407 source: config.data,
408 meta: config.meta,
409 graphics_pipeline,
410 })
411 })
412 .collect();
413 filters
414 };
415
416 #[cfg(target_arch = "wasm32")]
417 let filters = filter_creation_fn();
418
419 #[cfg(not(target_arch = "wasm32"))]
420 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
421 .stack_size(10 * 1048576)
423 .build()
424 {
425 thread_pool.install(|| filter_creation_fn())
426 } else {
427 filter_creation_fn()
428 };
429
430 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
431 let filters = filters?;
432 Ok(filters.into_boxed_slice())
433 }
434
    /// Records the commands that run the enabled passes of the filter chain over `input`,
    /// drawing the final pass into `viewport`.
    ///
    /// Only the first `min(passes.len(), config.passes_enabled())` passes run.
    /// `frame_count` is forwarded to each pass through `pass.meta.get_frame_count`.
    /// When `options` is `None`, the chain's default frame options are used. After
    /// drawing, `input` is copied into the history framebuffers for later frames.
    ///
    /// If no passes are enabled, this returns `Ok(())` early without drawing or recording
    /// history. A requested history clear is still performed in that case.
    ///
    /// NOTE(review): this method records into `cmd` but does not call `queue.submit`
    /// itself; presumably the caller submits `cmd`.
    ///
    /// # Errors
    /// Propagates errors from framebuffer scaling, render-target creation, and pass drawing.
    pub fn frame<'a>(
        &mut self,
        input: &wgpu::Texture,
        viewport: &Viewport<WgpuOutputView<'a>>,
        cmd: &mut wgpu::CommandEncoder,
        frame_count: usize,
        options: Option<&FrameOptionsWgpu>,
    ) -> error::Result<()> {
        // Only run as many passes as the runtime parameters currently enable.
        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
        let passes = &mut self.passes[0..max];

        // Clear every history frame if the caller asked for it, before anything samples them.
        if let Some(options) = &options {
            if options.clear_history {
                for history in &mut self.history_framebuffers {
                    history.clear(cmd);
                }
            }
        }

        if passes.is_empty() {
            return Ok(());
        }

        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());

        // History frames and the Original input are sampled with the first pass's
        // filter and wrap mode.
        let filter = passes[0].meta.filter;
        let wrap_mode = passes[0].meta.wrap_mode;

        // Expose the stored history frames (front = most recent) as inputs for this frame.
        for (texture, image) in self
            .common
            .history_textures
            .iter_mut()
            .zip(self.history_framebuffers.iter())
        {
            *texture = Some(image.as_input(filter, wrap_mode));
        }

        let original = InputImage {
            image: input.clone(),
            view: original_image_view,
            wrap_mode,
            filter_mode: filter,
            mip_filter: filter,
        };

        // `source` is the input of the pass currently being drawn; it starts as the
        // original image and advances to each pass's output.
        let mut source = original.clone();

        let passes_len = passes.len();
        let options = options.unwrap_or(&self.default_frame_options);

        // For passes with feedback, last frame's output becomes this frame's feedback
        // image, and the previous feedback framebuffer is reused as the render target.
        for index in 0..passes_len {
            if self.feedback_framebuffers.contains(index) {
                std::mem::swap(
                    &mut self.output_framebuffers[index],
                    &mut self.feedback_framebuffers[index],
                );
            }
        }

        let scale_context = self.common.device.clone();

        // Resize feedback framebuffers as needed and publish them as feedback inputs.
        OwnedImage::scale_feedback_framebuffers_with_context(
            source.image.size().into(),
            viewport.output.size,
            original.image.size().into(),
            &mut self.feedback_framebuffers,
            passes,
            &scale_context,
            |index, pass, feedback| {
                self.common.feedback_textures[index] =
                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
                Ok(())
            },
        )?;

        // Resize output framebuffers as needed and draw the passes. The callback is
        // presumably invoked once per pass, in order, with that pass's target and size.
        OwnedImage::scale_output_framebuffers_with_context(
            source.image.size().into(),
            viewport.output.size,
            original.image.size().into(),
            &mut self.output_framebuffers,
            passes,
            &scale_context,
            |index, pass, target, size| {
                // The source is sampled with this pass's filter/wrap settings.
                source.filter_mode = pass.meta.filter;
                source.wrap_mode = pass.meta.wrap_mode;
                source.mip_filter = pass.meta.filter;
                let frame_count_pass = pass.meta.get_frame_count(frame_count);

                // Intermediate pass: draw offscreen into this pass's own framebuffer.
                if index != passes_len - 1 {
                    let output_image = WgpuOutputView::from(&*target);
                    let out = RenderTarget::identity(&output_image)?;

                    pass.draw(
                        cmd,
                        index,
                        &self.common,
                        frame_count_pass,
                        options,
                        viewport,
                        &original,
                        &source,
                        &out,
                        None,
                        QuadType::Offscreen,
                    )?;

                    if target.max_miplevels > 1 && !self.disable_mipmaps {
                        let sampler = self.common.samplers.get(
                            WrapMode::ClampToEdge,
                            FilterMode::Linear,
                            FilterMode::Nearest,
                        );

                        target.generate_mipmaps(
                            &self.common.device,
                            cmd,
                            &mut self.mipmapper,
                            &sampler,
                        );
                    }

                    // Publish this pass's output and make it the next pass's source.
                    self.common.output_textures[index] =
                        Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode));
                    source = self.common.output_textures[index].clone().unwrap();
                    return Ok(());
                }

                // Final pass: make sure its pipeline targets the viewport's output format.
                // NOTE(review): when `draw_last_pass_feedback` is set, this pipeline also
                // draws into `target` below; confirm that `target`'s format is compatible
                // with `viewport.output.format`.
                if !pass.graphics_pipeline.has_format(viewport.output.format) {
                    pass.graphics_pipeline
                        .recompile(&self.common.device, viewport.output.format);
                }

                // If the final pass's output is read back as feedback, also render it into
                // its own framebuffer. That framebuffer's size is then passed as the output
                // size override for the viewport draw.
                let output_size_override = if self.draw_last_pass_feedback {
                    let output_image = WgpuOutputView::from(&*target);
                    let out = RenderTarget::viewport_with_output(&output_image, viewport);

                    pass.draw(
                        cmd,
                        index,
                        &self.common,
                        frame_count_pass,
                        options,
                        viewport,
                        &original,
                        &source,
                        &out,
                        None,
                        QuadType::Final,
                    )?;
                    Some(size)
                } else {
                    None
                };

                // Draw the final pass into the caller's viewport.
                let out = RenderTarget::viewport(viewport);
                pass.draw(
                    cmd,
                    index,
                    &self.common,
                    frame_count_pass,
                    options,
                    viewport,
                    &original,
                    &source,
                    &out,
                    output_size_override,
                    QuadType::Final,
                )?;

                Ok(())
            },
        )?;

        // Store this frame's input so later frames can sample it as history.
        self.push_history(&input, cmd);
        Ok(())
    }
622}