1use crate::buffer::MetalBuffer;
2use crate::draw_quad::DrawQuad;
3use crate::error;
4use crate::error::FilterChainError;
5use crate::filter_pass::FilterPass;
6use crate::graphics_pipeline::MetalGraphicsPipeline;
7use crate::luts::LutTexture;
8use crate::options::{FilterChainOptionsMetal, FrameOptionsMetal};
9use crate::samplers::SamplerSet;
10use crate::texture::{get_texture_size, InputTexture, MetalTextureRef, OwnedTexture};
11use librashader_common::map::FastHashMap;
12use librashader_common::{ImageFormat, Size, Viewport};
13use librashader_presets::context::VideoDriver;
14use librashader_presets::{ShaderFeatures, ShaderPreset};
15use librashader_reflect::back::msl::MslVersion;
16use librashader_reflect::back::targets::MSL;
17use librashader_reflect::back::{CompileReflectShader, CompileShader};
18use librashader_reflect::front::SpirvCompilation;
19use librashader_reflect::reflect::cross::SpirvCross;
20use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
21use librashader_reflect::reflect::semantics::ShaderSemantics;
22use librashader_reflect::reflect::ReflectShader;
23use librashader_runtime::binding::BindingUtil;
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection, BGRA8};
26use librashader_runtime::quad::QuadType;
27use librashader_runtime::render_target::RenderTarget;
28use librashader_runtime::scaling::ScaleFramebuffer;
29use librashader_runtime::uniforms::UniformStorage;
30use objc2::rc::Retained;
31use objc2::runtime::ProtocolObject;
32use objc2_foundation::NSString;
33use objc2_metal::{
34 MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLDevice, MTLLoadAction, MTLPixelFormat,
35 MTLRenderPassDescriptor, MTLResource, MTLStoreAction, MTLTexture,
36};
37use rayon::prelude::*;
38use std::collections::VecDeque;
39use std::fmt::{Debug, Formatter};
40use std::path::Path;
41
42mod compile {
43 use super::*;
44 use librashader_pack::{PassResource, TextureResource};
45
46 #[cfg(not(feature = "stable"))]
47 pub type ShaderPassMeta =
48 ShaderPassArtifact<impl CompileReflectShader<MSL, SpirvCompilation, SpirvCross> + Send>;
49
50 #[cfg(feature = "stable")]
51 pub type ShaderPassMeta =
52 ShaderPassArtifact<Box<dyn CompileReflectShader<MSL, SpirvCompilation, SpirvCross> + Send>>;
53
54 #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
55 pub fn compile_passes(
56 shaders: Vec<PassResource>,
57 textures: &[TextureResource],
58 ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
59 let (passes, semantics) = MSL::compile_preset_passes::<
60 SpirvCompilation,
61 SpirvCross,
62 FilterChainError,
63 >(shaders, textures.iter().map(|t| &t.meta))?;
64 Ok((passes, semantics))
65 }
66}
67
68use compile::{compile_passes, ShaderPassMeta};
69use librashader_pack::{ShaderPresetPack, TextureResource};
70use librashader_runtime::parameters::RuntimeParameters;
71
72pub struct FilterChainMetal {
74 pub(crate) common: FilterCommon,
75 passes: Box<[FilterPass]>,
76 output_framebuffers: Box<[OwnedTexture]>,
77 feedback_framebuffers: Box<[OwnedTexture]>,
78 history_framebuffers: VecDeque<OwnedTexture>,
79 prev_frame_history_buffer: OwnedTexture,
87 disable_mipmaps: bool,
88 default_options: FrameOptionsMetal,
89 draw_last_pass_feedback: bool,
90}
91
92impl Debug for FilterChainMetal {
93 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
94 f.write_fmt(format_args!("FilterChainMetal"))
95 }
96}
97
98pub(crate) struct FilterCommon {
99 pub output_textures: Box<[Option<InputTexture>]>,
100 pub feedback_textures: Box<[Option<InputTexture>]>,
101 pub history_textures: Box<[Option<InputTexture>]>,
102 pub luts: FastHashMap<usize, LutTexture>,
103 pub samplers: SamplerSet,
104 pub config: RuntimeParameters,
105 pub(crate) draw_quad: DrawQuad,
106 device: Retained<ProtocolObject<dyn MTLDevice>>,
107}
108
109impl FilterChainMetal {
110 pub fn load_from_path(
112 path: impl AsRef<Path>,
113 features: ShaderFeatures,
114 queue: &ProtocolObject<dyn MTLCommandQueue>,
115 options: Option<&FilterChainOptionsMetal>,
116 ) -> error::Result<FilterChainMetal> {
117 let preset =
119 ShaderPreset::try_parse_with_driver_context(path, features, VideoDriver::Metal)?;
120 Self::load_from_preset(preset, queue, options)
121 }
122
123 pub fn load_from_preset(
125 preset: ShaderPreset,
126 queue: &ProtocolObject<dyn MTLCommandQueue>,
127 options: Option<&FilterChainOptionsMetal>,
128 ) -> error::Result<FilterChainMetal> {
129 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
130 Self::load_from_pack(preset, queue, options)
131 }
132
133 pub fn load_from_pack(
135 preset: ShaderPresetPack,
136 queue: &ProtocolObject<dyn MTLCommandQueue>,
137 options: Option<&FilterChainOptionsMetal>,
138 ) -> error::Result<FilterChainMetal> {
139 let cmd = queue
140 .commandBuffer()
141 .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
142
143 let filter_chain =
144 Self::load_from_pack_deferred_internal(preset, queue.device(), &cmd, options)?;
145
146 cmd.commit();
147 unsafe { cmd.waitUntilCompleted() };
148
149 Ok(filter_chain)
150 }
151
152 fn load_luts(
153 device: &ProtocolObject<dyn MTLDevice>,
154 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
155 textures: Vec<TextureResource>,
156 ) -> error::Result<FastHashMap<usize, LutTexture>> {
157 let mut luts = FastHashMap::default();
158
159 let mipmapper = cmd
160 .blitCommandEncoder()
161 .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
162
163 let textures = textures
164 .into_par_iter()
165 .map(|texture| LoadedTexture::<BGRA8>::from_texture(texture, UVDirection::TopLeft))
166 .collect::<Result<Vec<LoadedTexture<BGRA8>>, ImageError>>()?;
167 for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
168 let texture = LutTexture::new(device, image, &meta, &mipmapper)?;
169 luts.insert(index, texture);
170 }
171
172 mipmapper.endEncoding();
173 Ok(luts)
174 }
175
176 fn init_passes(
177 device: &Retained<ProtocolObject<dyn MTLDevice>>,
178 passes: Vec<ShaderPassMeta>,
179 semantics: &ShaderSemantics,
180 ) -> error::Result<Box<[FilterPass]>> {
181 let filters: Vec<error::Result<FilterPass>> = passes
183 .into_iter()
184 .enumerate()
185 .map(|(index, (config, mut reflect))| {
186 let reflection = reflect.reflect(index, semantics)?;
187 let msl = reflect.compile(Some(MslVersion::new(2, 0, 0)))?;
188
189 let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
190 let push_size = reflection
191 .push_constant
192 .as_ref()
193 .map_or(0, |push| push.size);
194
195 let uniform_storage = UniformStorage::new_with_storage(
196 MetalBuffer::new(&device, ubo_size, "ubo")?,
197 MetalBuffer::new(&device, push_size as usize, "pcb")?,
198 );
199
200 let uniform_bindings = reflection.meta.create_binding_map(|param| param.offset());
201
202 let render_pass_format: MTLPixelFormat =
203 if let Some(format) = config.meta.get_format_override() {
204 format.into()
205 } else {
206 config.data.format.into()
207 };
208
209 let graphics_pipeline = MetalGraphicsPipeline::new(
210 &device,
211 &msl,
212 if render_pass_format == MTLPixelFormat(0) {
213 MTLPixelFormat::RGBA8Unorm
214 } else {
215 render_pass_format
216 },
217 )?;
218
219 Ok(FilterPass {
220 reflection,
221 uniform_storage,
222 uniform_bindings,
223 source: config.data,
224 meta: config.meta,
225 graphics_pipeline,
226 })
227 })
228 .collect();
229 let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
231 let filters = filters?;
232 Ok(filters.into_boxed_slice())
233 }
234
235 fn push_history(
236 &mut self,
237 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
238 input: &ProtocolObject<dyn MTLTexture>,
239 ) -> error::Result<()> {
240 let Some(mut back) = self.history_framebuffers.pop_back() else {
242 return Ok(());
243 };
244
245 std::mem::swap(&mut back, &mut self.prev_frame_history_buffer);
247 self.history_framebuffers.push_front(back);
248
249 let back = &mut self.prev_frame_history_buffer;
252 let mipmapper = cmd
253 .blitCommandEncoder()
254 .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
255 if back.texture.height() != input.height()
256 || back.texture.width() != input.width()
257 || input.pixelFormat() != back.texture.pixelFormat()
258 {
259 let size = Size {
260 width: input.width() as u32,
261 height: input.height() as u32,
262 };
263
264 let _old_back = std::mem::replace(
265 back,
266 OwnedTexture::new(&self.common.device, size, 1, input.pixelFormat())?,
267 );
268 }
269
270 back.copy_from(&mipmapper, input)?;
271 mipmapper.endEncoding();
272 Ok(())
273 }
274
275 pub fn load_from_preset_deferred(
283 preset: ShaderPreset,
284 queue: &ProtocolObject<dyn MTLCommandQueue>,
285 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
286 options: Option<&FilterChainOptionsMetal>,
287 ) -> error::Result<FilterChainMetal> {
288 let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
289 Self::load_from_pack_deferred(preset, queue, cmd, options)
290 }
291
292 pub fn load_from_pack_deferred(
300 preset: ShaderPresetPack,
301 queue: &ProtocolObject<dyn MTLCommandQueue>,
302 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
303 options: Option<&FilterChainOptionsMetal>,
304 ) -> error::Result<FilterChainMetal> {
305 Self::load_from_pack_deferred_internal(preset, queue.device(), &cmd, options)
306 }
307
308 fn load_from_pack_deferred_internal(
316 preset: ShaderPresetPack,
317 device: Retained<ProtocolObject<dyn MTLDevice>>,
318 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
319 options: Option<&FilterChainOptionsMetal>,
320 ) -> error::Result<FilterChainMetal> {
321 let config = RuntimeParameters::new(&preset);
322 let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
323
324 let filters = Self::init_passes(&device, passes, &semantics)?;
325
326 let samplers = SamplerSet::new(&device)?;
327 let luts = FilterChainMetal::load_luts(&device, &cmd, preset.textures)?;
328 let framebuffer_gen = || {
329 Ok::<_, error::FilterChainError>(OwnedTexture::new(
330 &device,
331 Size::new(1, 1),
332 1,
333 ImageFormat::R8G8B8A8Unorm.into(),
334 )?)
335 };
336 let input_gen = || None;
337 let framebuffer_init = FramebufferInit::new(
338 filters.iter().map(|f| &f.reflection.meta),
339 &framebuffer_gen,
340 &input_gen,
341 );
342 let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
343 let (feedback_framebuffers, feedback_textures) =
346 framebuffer_init.init_output_framebuffers()?;
347 let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
350
351 let history_buffer = framebuffer_gen()?;
352
353 let draw_quad = DrawQuad::new(&device)?;
354 Ok(FilterChainMetal {
355 draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
356 common: FilterCommon {
357 luts,
358 samplers,
359 config,
360 draw_quad,
361 device,
362 output_textures,
363 feedback_textures,
364 history_textures,
365 },
366 passes: filters,
367 output_framebuffers,
368 feedback_framebuffers,
369 history_framebuffers,
370 prev_frame_history_buffer: history_buffer,
371 disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
372 default_options: Default::default(),
373 })
374 }
375
376 pub fn frame(
380 &mut self,
381 input: &ProtocolObject<dyn MTLTexture>,
382 viewport: &Viewport<MetalTextureRef>,
383 cmd: &ProtocolObject<dyn MTLCommandBuffer>,
384 frame_count: usize,
385 options: Option<&FrameOptionsMetal>,
386 ) -> error::Result<()> {
387 let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
388 if let Some(options) = &options {
389 let clear_desc = unsafe { MTLRenderPassDescriptor::new() };
390 if options.clear_history {
391 for (index, history) in self.history_framebuffers.iter().enumerate() {
392 unsafe {
393 let ca = clear_desc
394 .colorAttachments()
395 .objectAtIndexedSubscript(index);
396 ca.setTexture(Some(&history.texture));
397 ca.setLoadAction(MTLLoadAction::Clear);
398 ca.setStoreAction(MTLStoreAction::Store);
399 }
400 }
401
402 let clearpass = cmd
403 .renderCommandEncoderWithDescriptor(&clear_desc)
404 .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
405 clearpass.endEncoding();
406 }
407 }
408
409 self.push_history(&cmd, &input)?;
410
411 let passes = &mut self.passes[0..max];
412 if passes.is_empty() {
413 return Ok(());
414 }
415
416 let filter = passes[0].meta.filter;
417 let wrap_mode = passes[0].meta.wrap_mode;
418
419 for (texture, image) in self
421 .common
422 .history_textures
423 .iter_mut()
424 .zip(self.history_framebuffers.iter())
425 {
426 *texture = Some(image.as_input(filter, wrap_mode)?);
427 }
428
429 let original = InputTexture {
430 texture: input
431 .newTextureViewWithPixelFormat(input.pixelFormat())
432 .ok_or(FilterChainError::FailedToCreateTexture)?,
433 wrap_mode,
434 filter_mode: filter,
435 mip_filter: filter,
436 };
437
438 let mut source = original.try_clone()?;
439
440 source
441 .texture
442 .setLabel(Some(&*NSString::from_str("librashader_sourcetex")));
443
444 std::mem::swap(
446 &mut self.output_framebuffers,
447 &mut self.feedback_framebuffers,
448 );
449
450 OwnedTexture::scale_framebuffers_with_context(
452 get_texture_size(&source.texture).into(),
453 get_texture_size(viewport.output),
454 get_texture_size(&original.texture).into(),
455 &mut self.output_framebuffers,
456 &mut self.feedback_framebuffers,
457 passes,
458 &self.common.device,
459 Some(&mut |index: usize,
460 pass: &FilterPass,
461 output: &OwnedTexture,
462 feedback: &OwnedTexture| {
463 self.common.feedback_textures[index] =
465 Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
466 self.common.output_textures[index] =
467 Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
468 Ok(())
469 }),
470 )?;
471
472 let passes_len = passes.len();
473 let (pass, last) = passes.split_at_mut(passes_len - 1);
474 let options = options.unwrap_or(&self.default_options);
475
476 for (index, pass) in pass.iter_mut().enumerate() {
477 let target = &self.output_framebuffers[index];
478 source.filter_mode = pass.meta.filter;
479 source.wrap_mode = pass.meta.wrap_mode;
480 source.mip_filter = pass.meta.filter;
481
482 let out = RenderTarget::identity(target.texture.as_ref())?;
483 pass.draw(
484 &cmd,
485 index,
486 &self.common,
487 pass.meta.get_frame_count(frame_count),
488 options,
489 viewport,
490 &original,
491 &source,
492 &out,
493 QuadType::Offscreen,
494 )?;
495
496 if target.max_miplevels > 1 && !self.disable_mipmaps {
497 target.generate_mipmaps(&cmd)?;
498 }
499
500 self.common.output_textures[index] =
501 Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
502 source = self.common.output_textures[index]
503 .as_ref()
504 .map(InputTexture::try_clone)
505 .unwrap()?;
506 }
507
508 assert_eq!(last.len(), 1);
510
511 if let Some(pass) = last.iter_mut().next() {
512 if !pass
513 .graphics_pipeline
514 .has_format(viewport.output.pixelFormat())
515 {
516 pass.graphics_pipeline
518 .recompile(&self.common.device, viewport.output.pixelFormat())?;
519 }
520
521 source.filter_mode = pass.meta.filter;
522 source.wrap_mode = pass.meta.wrap_mode;
523 source.mip_filter = pass.meta.filter;
524 let index = passes_len - 1;
525
526 if self.draw_last_pass_feedback {
527 let output_image = &self.output_framebuffers[index].texture;
528 let out = RenderTarget::viewport_with_output(output_image.as_ref(), viewport);
529 pass.draw(
530 &cmd,
531 passes_len - 1,
532 &self.common,
533 pass.meta.get_frame_count(frame_count),
534 options,
535 viewport,
536 &original,
537 &source,
538 &out,
539 QuadType::Final,
540 )?;
541 }
542
543 let out = RenderTarget::viewport(viewport);
544 pass.draw(
545 &cmd,
546 index,
547 &self.common,
548 pass.meta.get_frame_count(frame_count),
549 options,
550 viewport,
551 &original,
552 &source,
553 &out,
554 QuadType::Final,
555 )?;
556 }
557
558 Ok(())
559 }
560}