1use librashader_common::map::FastHashMap;
2use librashader_presets::{ShaderFeatures, ShaderPreset};
3use librashader_reflect::back::targets::WGSL;
4use librashader_reflect::back::{CompileReflectShader, CompileShader};
5use librashader_reflect::front::SpirvCompilation;
6use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
7use librashader_reflect::reflect::semantics::ShaderSemantics;
8use librashader_reflect::reflect::ReflectShader;
9use librashader_runtime::binding::BindingUtil;
10use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection};
11use librashader_runtime::quad::QuadType;
12use librashader_runtime::uniforms::UniformStorage;
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15use std::collections::VecDeque;
16use std::path::Path;
17
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::ThreadPoolBuilder;
20
21use crate::buffer::WgpuStagedBuffer;
22use crate::draw_quad::DrawQuad;
23use librashader_common::{FilterMode, Size, Viewport, WrapMode};
24use librashader_reflect::reflect::naga::{Naga, NagaLoweringOptions};
25use librashader_runtime::framebuffer::{FramebufferInit, FramebufferPool};
26use librashader_runtime::render_target::RenderTarget;
27use librashader_runtime::scaling::ScaleFramebuffer;
28
29use crate::error;
30use crate::error::FilterChainError;
31use crate::filter_pass::FilterPass;
32use crate::framebuffer::WgpuOutputView;
33use crate::graphics_pipeline::WgpuGraphicsPipeline;
34use crate::luts::LutTexture;
35use crate::mipmap::MipmapGen;
36use crate::options::{FilterChainOptionsWgpu, FrameOptionsWgpu};
37use crate::samplers::SamplerSet;
38use crate::texture::{InputImage, OwnedImage};
39
#[cfg(feature = "native")]
mod compile {
    //! Preset compilation that goes through the `SpirvCompilation` front end,
    //! is reflected with `Naga`, and targets `WGSL`.
    use super::*;
    use librashader_pack::{PassResource, TextureResource};

    /// One compiled pass of a preset: the pass configuration paired with a
    /// compile/reflect handle for the `WGSL` target.
    ///
    /// With the `nightly` feature the handle is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>;

    /// One compiled pass of a preset: the pass configuration paired with a
    /// compile/reflect handle for the `WGSL` target.
    ///
    /// Without the `nightly` feature the handle is a boxed trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, SpirvCompilation, Naga> + Send>>;

    /// Compiles all `shaders` of a preset for the `WGSL` target.
    ///
    /// The metadata of every entry in `textures` is handed to the compiler
    /// alongside the passes. On success the compiled passes are returned
    /// together with the shader semantics produced by the compiler.
    ///
    /// # Errors
    /// Propagates any [`FilterChainError`] reported by
    /// `WGSL::compile_preset_passes`.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        let (artifacts, semantics) =
            WGSL::compile_preset_passes::<SpirvCompilation, Naga, FilterChainError>(
                shaders,
                texture_metas,
            )?;
        Ok((artifacts, semantics))
    }
}
66
#[cfg(feature = "wgsl_preset_pack")]
mod compile_wgsl {
    //! Preset compilation that goes through the `WgslCompilation` front end,
    //! is reflected with `Naga`, and targets `WGSL`.
    use super::*;
    use librashader_pack::{PassResource, TextureResource};
    use librashader_reflect::front::WgslCompilation;

    /// One compiled pass of a preset: the pass configuration paired with a
    /// compile/reflect handle for the `WGSL` target.
    ///
    /// With the `nightly` feature the handle is an opaque `impl Trait` type.
    #[cfg(feature = "nightly")]
    pub type ShaderPassMeta =
        ShaderPassArtifact<impl CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>;

    /// One compiled pass of a preset: the pass configuration paired with a
    /// compile/reflect handle for the `WGSL` target.
    ///
    /// Without the `nightly` feature the handle is a boxed trait object.
    #[cfg(not(feature = "nightly"))]
    pub type ShaderPassMeta =
        ShaderPassArtifact<Box<dyn CompileReflectShader<WGSL, WgslCompilation, Naga> + Send>>;

    /// Compiles all `shaders` of a preset for the `WGSL` target.
    ///
    /// The metadata of every entry in `textures` is handed to the compiler
    /// alongside the passes. On success the compiled passes are returned
    /// together with the shader semantics produced by the compiler.
    ///
    /// # Errors
    /// Propagates any [`FilterChainError`] reported by
    /// `WGSL::compile_preset_passes`.
    #[cfg_attr(feature = "nightly", define_opaque(ShaderPassMeta))]
    pub fn compile_passes(
        shaders: Vec<PassResource>,
        textures: &[TextureResource],
    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
        let texture_metas = textures.iter().map(|texture| &texture.meta);
        let (artifacts, semantics) =
            WGSL::compile_preset_passes::<WgslCompilation, Naga, FilterChainError>(
                shaders,
                texture_metas,
            )?;
        Ok((artifacts, semantics))
    }
}
94
95#[cfg(all(feature = "native", not(feature = "wgsl_preset_pack")))]
96use compile::{compile_passes, ShaderPassMeta};
97
98#[cfg(any(not(feature = "native"), feature = "wgsl_preset_pack"))]
99use compile_wgsl::{compile_passes, ShaderPassMeta};
100
101use librashader_pack::{ShaderPresetPack, TextureResource};
102use librashader_runtime::parameters::RuntimeParameters;
103
/// A shader filter chain for wgpu: the compiled passes of a preset plus the
/// framebuffers and helpers needed to record one frame with [`FilterChainWgpu::frame`].
pub struct FilterChainWgpu {
    /// State shared with every pass while drawing (textures, LUTs, samplers,
    /// runtime parameters, device and queue).
    pub(crate) common: FilterCommon,
    /// The compiled filter passes, in preset order.
    passes: Box<[FilterPass]>,
    /// Render targets the passes draw into during `frame`.
    output_framebuffers: FramebufferPool<OwnedImage>,
    /// Framebuffers exposed to passes as feedback inputs; entries are swapped
    /// with the matching output framebuffer at the start of `frame`.
    feedback_framebuffers: FramebufferPool<OwnedImage>,
    /// Copies of previous input frames. `push_history` pops the back entry,
    /// copies the current input into it and pushes it to the front.
    history_framebuffers: VecDeque<OwnedImage>,
    /// When `true`, mipmaps are never generated for pass outputs
    /// (set from `FilterChainOptionsWgpu::force_no_mipmaps`).
    disable_mipmaps: bool,
    /// Mipmap generator used for LUTs at load time and pass outputs per frame.
    mipmapper: MipmapGen,
    /// Frame options used when `frame` is called with `options == None`.
    default_frame_options: FrameOptionsWgpu,
    /// Whether the final pass must additionally be drawn into its own
    /// framebuffer (value of `FramebufferInit::uses_final_pass_as_feedback`).
    draw_last_pass_feedback: bool,
}
116
/// State shared by all passes of a [`FilterChainWgpu`]; a reference to it is
/// handed to every `FilterPass::draw` call.
pub(crate) struct FilterCommon {
    /// Per-pass output of the current frame as a bindable input; a slot is
    /// filled in after the corresponding pass has been drawn.
    pub output_textures: Box<[Option<InputImage>]>,
    /// Per-pass feedback input, refreshed at the start of every frame.
    pub feedback_textures: Box<[Option<InputImage>]>,
    /// Bindable views of the history framebuffers, refreshed every frame.
    pub history_textures: Box<[Option<InputImage>]>,
    /// Lookup textures, keyed by the texture's index in the preset.
    pub luts: FastHashMap<usize, LutTexture>,
    /// Samplers available to the passes.
    pub samplers: SamplerSet,
    /// Runtime parameters created from the preset (e.g. `passes_enabled`).
    pub config: RuntimeParameters,
    /// Quad geometry used to draw each pass.
    pub(crate) draw_quad: DrawQuad,
    /// Clone of the device the chain was created with.
    pub(crate) device: wgpu::Device,
    /// Clone of the queue the chain was created with.
    pub(crate) queue: wgpu::Queue,
}
128
129impl FilterChainWgpu {
130 #[cfg(feature = "native")]
132 pub fn load_from_path(
133 path: impl AsRef<Path>,
134 features: ShaderFeatures,
135 device: &wgpu::Device,
136 queue: &wgpu::Queue,
137 options: Option<&FilterChainOptionsWgpu>,
138 ) -> error::Result<FilterChainWgpu> {
139 let preset = ShaderPreset::try_parse(path, features)?;
141
142 Self::load_from_preset(preset, device, queue, options)
143 }
144
145 #[cfg(feature = "native")]
147 pub fn load_from_preset(
148 preset: ShaderPreset,
149 device: &wgpu::Device,
150 queue: &wgpu::Queue,
151 options: Option<&FilterChainOptionsWgpu>,
152 ) -> error::Result<FilterChainWgpu> {
153 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
154 Self::load_from_pack(preset, device, queue, options)
155 }
156
157 pub fn load_from_pack(
159 preset: ShaderPresetPack,
160 device: &wgpu::Device,
161 queue: &wgpu::Queue,
162 options: Option<&FilterChainOptionsWgpu>,
163 ) -> error::Result<FilterChainWgpu> {
164 let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
165 label: Some("librashader load cmd"),
166 });
167 let filter_chain =
168 Self::load_from_pack_deferred(preset, &device, &queue, &mut cmd, options)?;
169
170 let cmd = cmd.finish();
171
172 let index = queue.submit([cmd]);
174 device.poll(wgpu::PollType::Wait {
175 submission_index: Some(index),
176 timeout: None,
177 })?;
178
179 Ok(filter_chain)
180 }
181
182 #[cfg(feature = "native")]
190 pub fn load_from_preset_deferred(
191 preset: ShaderPreset,
192 device: &wgpu::Device,
193 queue: &wgpu::Queue,
194 cmd: &mut wgpu::CommandEncoder,
195 options: Option<&FilterChainOptionsWgpu>,
196 ) -> error::Result<FilterChainWgpu> {
197 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
198 Self::load_from_pack_deferred(preset, device, queue, cmd, options)
199 }
200
201 pub fn load_from_pack_deferred(
209 preset: ShaderPresetPack,
210 device: &wgpu::Device,
211 queue: &wgpu::Queue,
212 cmd: &mut wgpu::CommandEncoder,
213 options: Option<&FilterChainOptionsWgpu>,
214 ) -> error::Result<FilterChainWgpu> {
215 let config = RuntimeParameters::new(&preset);
216
217 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
218
219 let disable_cache = options.map_or(true, |o| !o.enable_cache);
221
222 let filters = Self::init_passes(
224 &device,
225 passes,
226 &semantics,
227 options.and_then(|o| o.adapter_info.as_ref()),
228 disable_cache,
229 )?;
230
231 let samplers = SamplerSet::new(&device);
232 let mut mipmapper = MipmapGen::new(&device);
233 let luts = FilterChainWgpu::load_luts(
234 &device,
235 &queue,
236 cmd,
237 &mut mipmapper,
238 &samplers,
239 preset.textures,
240 )?;
241 let framebuffer_gen = || {
243 Ok::<_, error::FilterChainError>(OwnedImage::new(
244 &device,
245 Size::new(1, 1),
246 1,
247 wgpu::TextureFormat::Bgra8Unorm,
248 ))
249 };
250 let input_gen = || None;
251 let framebuffer_init = FramebufferInit::new(
252 filters.iter().map(|f| &f.reflection.meta),
253 &framebuffer_gen,
254 &input_gen,
255 );
256
257 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
260 let (feedback_framebuffers, feedback_textures) = framebuffer_init.init_feedback_framebuffers()?;
263 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
266
267 let draw_quad = DrawQuad::new(&device);
268
269 Ok(FilterChainWgpu {
270 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
271 common: FilterCommon {
272 luts,
273 samplers,
274 config,
275 draw_quad,
276 device: device.clone(),
277 queue: queue.clone(),
278 output_textures,
279 feedback_textures,
280 history_textures,
281 },
282 passes: filters,
283 output_framebuffers,
284 feedback_framebuffers,
285 history_framebuffers,
286 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
287 mipmapper,
288 default_frame_options: Default::default(),
289 })
290 }
291
292 fn load_luts(
293 device: &wgpu::Device,
294 queue: &wgpu::Queue,
295 cmd: &mut wgpu::CommandEncoder,
296 mipmapper: &mut MipmapGen,
297 sampler_set: &SamplerSet,
298 textures: Vec<TextureResource>,
299 ) -> error::Result<FastHashMap<usize, LutTexture>> {
300 let mut luts = FastHashMap::default();
301
302 #[cfg(not(target_arch = "wasm32"))]
303 let images_iter = textures.into_par_iter();
304
305 #[cfg(target_arch = "wasm32")]
306 let images_iter = textures.into_iter();
307
308 let textures = images_iter
309 .map(|texture| LoadedTexture::from_texture(texture, UVDirection::TopLeft))
310 .collect::<Result<Vec<LoadedTexture>, ImageError>>()?;
311 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
312 let texture = LutTexture::new(device, queue, cmd, image, &meta, mipmapper, sampler_set);
313 luts.insert(index, texture);
314 }
315 Ok(luts)
316 }
317
318 fn push_history(&mut self, input: &wgpu::Texture, cmd: &mut wgpu::CommandEncoder) {
319 if let Some(mut back) = self.history_framebuffers.pop_back() {
320 if back.image.size() != input.size() || input.format() != back.image.format() {
321 let _old_back = std::mem::replace(
323 &mut back,
324 OwnedImage::new(
325 &self.common.device,
326 input.size().into(),
327 1,
328 input.format().into(),
329 ),
330 );
331 }
332
333 back.copy_from(&self.common.device, cmd, input);
334
335 self.history_framebuffers.push_front(back)
336 }
337 }
338
339 fn init_passes(
340 device: &wgpu::Device,
341 passes: Vec<ShaderPassMeta>,
342 semantics: &ShaderSemantics,
343 adapter_info: Option<&wgpu::AdapterInfo>,
344 disable_cache: bool,
345 ) -> error::Result<Box<[FilterPass]>> {
346 let filter_creation_fn = || {
347 #[cfg(not(target_arch = "wasm32"))]
348 let passes_iter = passes.into_par_iter();
349 #[cfg(target_arch = "wasm32")]
350 let passes_iter = passes.into_iter();
351
352 let filters: Vec<error::Result<FilterPass>> = passes_iter
353 .enumerate()
354 .map(|(index, (config, mut reflect))| {
355 let reflection = reflect.reflect(index, semantics)?;
356 let wgsl = reflect.compile(NagaLoweringOptions {
357 write_pcb_as_ubo: true,
358 sampler_bind_group: 1,
359 suppress_derivative_uniformity: true,
360 })?;
361
362 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
363 let push_size = reflection
364 .push_constant
365 .as_ref()
366 .map_or(0, |push| push.size as wgpu::BufferAddress);
367
368 let uniform_storage = UniformStorage::new_with_storage(
369 WgpuStagedBuffer::new(
370 &device,
371 wgpu::BufferUsages::UNIFORM,
372 ubo_size as wgpu::BufferAddress,
373 Some("ubo"),
374 ),
375 WgpuStagedBuffer::new(
376 &device,
377 wgpu::BufferUsages::UNIFORM,
378 push_size as wgpu::BufferAddress,
379 Some("push"),
380 ),
381 );
382
383 let uniform_bindings =
384 reflection.meta.create_binding_map(|param| param.offset());
385
386 let render_pass_format: Option<wgpu::TextureFormat> =
387 if let Some(format) = config.meta.get_format_override() {
388 format.into()
389 } else {
390 config.data.format.into()
391 };
392
393 let graphics_pipeline = WgpuGraphicsPipeline::new(
394 &device,
395 &wgsl,
396 &reflection,
397 render_pass_format.unwrap_or(wgpu::TextureFormat::Rgba8Unorm),
398 adapter_info,
399 disable_cache,
400 );
401
402 Ok(FilterPass {
403 reflection,
404 uniform_storage,
405 uniform_bindings,
406 source: config.data,
407 meta: config.meta,
408 graphics_pipeline,
409 })
410 })
411 .collect();
412 filters
413 };
414
415 #[cfg(target_arch = "wasm32")]
416 let filters = filter_creation_fn();
417
418 #[cfg(not(target_arch = "wasm32"))]
419 let filters = if let Ok(thread_pool) = ThreadPoolBuilder::new()
420 .stack_size(10 * 1048576)
422 .build()
423 {
424 thread_pool.install(|| filter_creation_fn())
425 } else {
426 filter_creation_fn()
427 };
428
429 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
430 let filters = filters?;
431 Ok(filters.into_boxed_slice())
432 }
433
    /// Records the commands to run the filter chain on `input` into `cmd`,
    /// with the final pass drawn to `viewport`.
    ///
    /// * `input` - the source texture for this frame.
    /// * `viewport` - the output view and placement for the final pass.
    /// * `cmd` - encoder that receives all draw, mipmap and copy commands;
    ///   this function does not submit it.
    /// * `frame_count` - frame counter forwarded to each pass through
    ///   `meta.get_frame_count`.
    /// * `options` - per-frame options; the chain's default frame options are
    ///   used when `None`.
    ///
    /// If no passes are enabled the function returns early without drawing
    /// and without pushing `input` into the history.
    ///
    /// # Errors
    /// Propagates errors from framebuffer scaling, render-target creation and
    /// pass drawing.
    pub fn frame<'a>(
        &mut self,
        input: &wgpu::Texture,
        viewport: &Viewport<WgpuOutputView<'a>>,
        cmd: &mut wgpu::CommandEncoder,
        frame_count: usize,
        options: Option<&FrameOptionsWgpu>,
    ) -> error::Result<()> {
        // Run at most `passes_enabled()` passes, capped at the number loaded.
        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
        let passes = &mut self.passes[0..max];

        // History is cleared on request even if no pass ends up running.
        if let Some(options) = &options {
            if options.clear_history {
                for history in &mut self.history_framebuffers {
                    history.clear(cmd);
                }
            }
        }

        if passes.is_empty() {
            return Ok(());
        }

        let original_image_view = input.create_view(&wgpu::TextureViewDescriptor::default());

        // The original and history inputs use the first pass's filter and wrap mode.
        let filter = passes[0].meta.filter;
        let wrap_mode = passes[0].meta.wrap_mode;

        // Refresh the bindable history inputs from the history framebuffers.
        for (texture, image) in self
            .common
            .history_textures
            .iter_mut()
            .zip(self.history_framebuffers.iter())
        {
            *texture = Some(image.as_input(filter, wrap_mode));
        }

        let original = InputImage {
            image: input.clone(),
            view: original_image_view,
            wrap_mode,
            filter_mode: filter,
            mip_filter: filter,
        };

        // `source` is the input of the pass being drawn; it starts as the original.
        let mut source = original.clone();

        let passes_len = passes.len();
        let options = options.unwrap_or(&self.default_frame_options);

        // For every pass that has a feedback framebuffer, exchange it with the
        // pass's output framebuffer before anything is drawn (presumably so the
        // previous frame's output becomes this frame's feedback input).
        for index in 0..passes_len {
            if self.feedback_framebuffers.contains(index) {
                std::mem::swap(&mut self.output_framebuffers[index], &mut self.feedback_framebuffers[index]);
            }
        }

        // A separate device handle for the scaling helpers; presumably cloned so
        // they do not borrow `self.common` while the callbacks below mutate it.
        let scale_context = self.common.device.clone();

        // The callback publishes each pass's feedback framebuffer as a bindable input.
        OwnedImage::scale_feedback_framebuffers_with_context(
            source.image.size().into(),
            viewport.output.size,
            original.image.size().into(),
            &mut self.feedback_framebuffers,
            passes,
            &scale_context,
            |index, pass, feedback| {
                self.common.feedback_textures[index] =
                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode));
                Ok(())
            },
        )?;

        // The callback draws each pass into the `target` it is handed
        // (NOTE(review): assumes the helper invokes it once per pass, in order).
        OwnedImage::scale_output_framebuffers_with_context(
            source.image.size().into(),
            viewport.output.size,
            original.image.size().into(),
            &mut self.output_framebuffers,
            passes,
            &scale_context,
            |index, pass, target, size| {
                // The source is sampled with the current pass's filter and wrap mode.
                source.filter_mode = pass.meta.filter;
                source.wrap_mode = pass.meta.wrap_mode;
                source.mip_filter = pass.meta.filter;
                let frame_count_pass = pass.meta.get_frame_count(frame_count);

                // Every pass except the last renders offscreen into its own target.
                if index != passes_len - 1 {
                    let output_image = WgpuOutputView::from(&*target);
                    let out = RenderTarget::identity(&output_image)?;

                    pass.draw(
                        cmd,
                        index,
                        &self.common,
                        frame_count_pass,
                        options,
                        viewport,
                        &original,
                        &source,
                        &out,
                        None,
                        QuadType::Offscreen,
                    )?;

                    // Mipmaps are generated only for multi-level targets and
                    // only when not disabled for the chain.
                    if target.max_miplevels > 1 && !self.disable_mipmaps {
                        let sampler = self.common.samplers.get(
                            WrapMode::ClampToEdge,
                            FilterMode::Linear,
                            FilterMode::Nearest,
                        );

                        target.generate_mipmaps(
                            &self.common.device,
                            cmd,
                            &mut self.mipmapper,
                            &sampler,
                        );
                    }

                    // Publish this pass's output and feed it to the next pass.
                    self.common.output_textures[index] =
                        Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode));
                    source = self.common.output_textures[index].clone().unwrap();
                    return Ok(());
                }

                // Final pass: its pipeline must match the viewport's output format.
                if !pass.graphics_pipeline.has_format(viewport.output.format) {
                    pass.graphics_pipeline
                        .recompile(&self.common.device, viewport.output.format);
                }

                // When the final pass is also needed as feedback, draw it into its
                // own target first and pass `size` on as the output size override
                // for the viewport draw below.
                let output_size_override = if self.draw_last_pass_feedback {
                    let output_image = WgpuOutputView::from(&*target);
                    let out = RenderTarget::viewport_with_output(&output_image, viewport);

                    pass.draw(
                        cmd,
                        index,
                        &self.common,
                        frame_count_pass,
                        options,
                        viewport,
                        &original,
                        &source,
                        &out,
                        None,
                        QuadType::Final,
                    )?;
                    Some(size)
                } else {
                    None
                };

                // Draw the final pass to the caller's viewport.
                let out = RenderTarget::viewport(viewport);
                pass.draw(
                    cmd,
                    index,
                    &self.common,
                    frame_count_pass,
                    options,
                    viewport,
                    &original,
                    &source,
                    &out,
                    output_size_override,
                    QuadType::Final,
                )?;

                Ok(())
            },
        )?;

        // Record this frame's input as the newest history entry.
        self.push_history(&input, cmd);
        Ok(())
    }
618}