librashader_runtime_mtl/filter_chain.rs

1use crate::buffer::MetalBuffer;
2use crate::draw_quad::DrawQuad;
3use crate::error;
4use crate::error::FilterChainError;
5use crate::filter_pass::FilterPass;
6use crate::graphics_pipeline::MetalGraphicsPipeline;
7use crate::luts::LutTexture;
8use crate::options::{FilterChainOptionsMetal, FrameOptionsMetal};
9use crate::samplers::SamplerSet;
10use crate::texture::{get_texture_size, InputTexture, MetalTextureRef, OwnedTexture};
11use librashader_common::map::FastHashMap;
12use librashader_common::{ImageFormat, Size, Viewport};
13use librashader_presets::context::VideoDriver;
14use librashader_presets::{ShaderFeatures, ShaderPreset};
15use librashader_reflect::back::msl::MslVersion;
16use librashader_reflect::back::targets::MSL;
17use librashader_reflect::back::{CompileReflectShader, CompileShader};
18use librashader_reflect::front::SpirvCompilation;
19use librashader_reflect::reflect::cross::SpirvCross;
20use librashader_reflect::reflect::presets::{CompilePresetTarget, ShaderPassArtifact};
21use librashader_reflect::reflect::semantics::ShaderSemantics;
22use librashader_reflect::reflect::ReflectShader;
23use librashader_runtime::binding::BindingUtil;
24use librashader_runtime::framebuffer::FramebufferInit;
25use librashader_runtime::image::{ImageError, LoadedTexture, UVDirection, BGRA8};
26use librashader_runtime::quad::QuadType;
27use librashader_runtime::render_target::RenderTarget;
28use librashader_runtime::scaling::ScaleFramebuffer;
29use librashader_runtime::uniforms::UniformStorage;
30use objc2::rc::Retained;
31use objc2::runtime::ProtocolObject;
32use objc2_foundation::NSString;
33use objc2_metal::{
34    MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLDevice, MTLLoadAction, MTLPixelFormat,
35    MTLRenderPassDescriptor, MTLResource, MTLStoreAction, MTLTexture,
36};
37use rayon::prelude::*;
38use std::collections::VecDeque;
39use std::fmt::{Debug, Formatter};
40use std::path::Path;
41
42mod compile {
43    use super::*;
44    use librashader_pack::{PassResource, TextureResource};
45
46    #[cfg(not(feature = "stable"))]
47    pub type ShaderPassMeta =
48        ShaderPassArtifact<impl CompileReflectShader<MSL, SpirvCompilation, SpirvCross> + Send>;
49
50    #[cfg(feature = "stable")]
51    pub type ShaderPassMeta =
52        ShaderPassArtifact<Box<dyn CompileReflectShader<MSL, SpirvCompilation, SpirvCross> + Send>>;
53
54    #[cfg_attr(not(feature = "stable"), define_opaque(ShaderPassMeta))]
55    pub fn compile_passes(
56        shaders: Vec<PassResource>,
57        textures: &[TextureResource],
58    ) -> Result<(Vec<ShaderPassMeta>, ShaderSemantics), FilterChainError> {
59        let (passes, semantics) = MSL::compile_preset_passes::<
60            SpirvCompilation,
61            SpirvCross,
62            FilterChainError,
63        >(shaders, textures.iter().map(|t| &t.meta))?;
64        Ok((passes, semantics))
65    }
66}
67
68use compile::{compile_passes, ShaderPassMeta};
69use librashader_pack::{ShaderPresetPack, TextureResource};
70use librashader_runtime::parameters::RuntimeParameters;
71
72/// A Metal filter chain.
73pub struct FilterChainMetal {
74    pub(crate) common: FilterCommon,
75    passes: Box<[FilterPass]>,
76    output_framebuffers: Box<[OwnedTexture]>,
77    feedback_framebuffers: Box<[OwnedTexture]>,
78    history_framebuffers: VecDeque<OwnedTexture>,
79    /// Metal does not allow us to push the input texture to history
80    /// before recording framebuffers, so we double-buffer it.
81    ///
82    /// First we swap OriginalHistory1 with the contents of this buffer (which were written to
83    /// in the previous frame)
84    ///
85    /// Then we blit the original to the buffer.
86    prev_frame_history_buffer: OwnedTexture,
87    disable_mipmaps: bool,
88    default_options: FrameOptionsMetal,
89    draw_last_pass_feedback: bool,
90}
91
92impl Debug for FilterChainMetal {
93    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
94        f.write_fmt(format_args!("FilterChainMetal"))
95    }
96}
97
98pub(crate) struct FilterCommon {
99    pub output_textures: Box<[Option<InputTexture>]>,
100    pub feedback_textures: Box<[Option<InputTexture>]>,
101    pub history_textures: Box<[Option<InputTexture>]>,
102    pub luts: FastHashMap<usize, LutTexture>,
103    pub samplers: SamplerSet,
104    pub config: RuntimeParameters,
105    pub(crate) draw_quad: DrawQuad,
106    device: Retained<ProtocolObject<dyn MTLDevice>>,
107}
108
109impl FilterChainMetal {
110    /// Load the shader preset at the given path into a filter chain.
111    pub fn load_from_path(
112        path: impl AsRef<Path>,
113        features: ShaderFeatures,
114        queue: &ProtocolObject<dyn MTLCommandQueue>,
115        options: Option<&FilterChainOptionsMetal>,
116    ) -> error::Result<FilterChainMetal> {
117        // load passes from preset
118        let preset =
119            ShaderPreset::try_parse_with_driver_context(path, features, VideoDriver::Metal)?;
120        Self::load_from_preset(preset, queue, options)
121    }
122
123    /// Load a filter chain from a pre-parsed `ShaderPreset`.
124    pub fn load_from_preset(
125        preset: ShaderPreset,
126        queue: &ProtocolObject<dyn MTLCommandQueue>,
127        options: Option<&FilterChainOptionsMetal>,
128    ) -> error::Result<FilterChainMetal> {
129        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
130        Self::load_from_pack(preset, queue, options)
131    }
132
133    /// Load a filter chain from a pre-parsed `ShaderPreset`.
134    pub fn load_from_pack(
135        preset: ShaderPresetPack,
136        queue: &ProtocolObject<dyn MTLCommandQueue>,
137        options: Option<&FilterChainOptionsMetal>,
138    ) -> error::Result<FilterChainMetal> {
139        let cmd = queue
140            .commandBuffer()
141            .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
142
143        let filter_chain =
144            Self::load_from_pack_deferred_internal(preset, queue.device(), &cmd, options)?;
145
146        cmd.commit();
147        unsafe { cmd.waitUntilCompleted() };
148
149        Ok(filter_chain)
150    }
151
152    fn load_luts(
153        device: &ProtocolObject<dyn MTLDevice>,
154        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
155        textures: Vec<TextureResource>,
156    ) -> error::Result<FastHashMap<usize, LutTexture>> {
157        let mut luts = FastHashMap::default();
158
159        let mipmapper = cmd
160            .blitCommandEncoder()
161            .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
162
163        let textures = textures
164            .into_par_iter()
165            .map(|texture| LoadedTexture::<BGRA8>::from_texture(texture, UVDirection::TopLeft))
166            .collect::<Result<Vec<LoadedTexture<BGRA8>>, ImageError>>()?;
167        for (index, LoadedTexture { meta, image }) in textures.into_iter().enumerate() {
168            let texture = LutTexture::new(device, image, &meta, &mipmapper)?;
169            luts.insert(index, texture);
170        }
171
172        mipmapper.endEncoding();
173        Ok(luts)
174    }
175
176    fn init_passes(
177        device: &Retained<ProtocolObject<dyn MTLDevice>>,
178        passes: Vec<ShaderPassMeta>,
179        semantics: &ShaderSemantics,
180    ) -> error::Result<Box<[FilterPass]>> {
181        // todo: fix this to allow send
182        let filters: Vec<error::Result<FilterPass>> = passes
183            .into_iter()
184            .enumerate()
185            .map(|(index, (config, mut reflect))| {
186                let reflection = reflect.reflect(index, semantics)?;
187                let msl = reflect.compile(Some(MslVersion::new(2, 0, 0)))?;
188
189                let ubo_size = reflection.ubo.as_ref().map_or(0, |ubo| ubo.size as usize);
190                let push_size = reflection
191                    .push_constant
192                    .as_ref()
193                    .map_or(0, |push| push.size);
194
195                let uniform_storage = UniformStorage::new_with_storage(
196                    MetalBuffer::new(&device, ubo_size, "ubo")?,
197                    MetalBuffer::new(&device, push_size as usize, "pcb")?,
198                );
199
200                let uniform_bindings = reflection.meta.create_binding_map(|param| param.offset());
201
202                let render_pass_format: MTLPixelFormat =
203                    if let Some(format) = config.meta.get_format_override() {
204                        format.into()
205                    } else {
206                        config.data.format.into()
207                    };
208
209                let graphics_pipeline = MetalGraphicsPipeline::new(
210                    &device,
211                    &msl,
212                    if render_pass_format == MTLPixelFormat(0) {
213                        MTLPixelFormat::RGBA8Unorm
214                    } else {
215                        render_pass_format
216                    },
217                )?;
218
219                Ok(FilterPass {
220                    reflection,
221                    uniform_storage,
222                    uniform_bindings,
223                    source: config.data,
224                    meta: config.meta,
225                    graphics_pipeline,
226                })
227            })
228            .collect();
229        //
230        let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
231        let filters = filters?;
232        Ok(filters.into_boxed_slice())
233    }
234
235    fn push_history(
236        &mut self,
237        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
238        input: &ProtocolObject<dyn MTLTexture>,
239    ) -> error::Result<()> {
240        // If there's no history, there's no need to do any of this.
241        let Some(mut back) = self.history_framebuffers.pop_back() else {
242            return Ok(());
243        };
244
245        // Push the previous frame as OriginalHistory1
246        std::mem::swap(&mut back, &mut self.prev_frame_history_buffer);
247        self.history_framebuffers.push_front(back);
248
249        // Copy the current frame into prev_frame_history_buffer, which will be
250        // pushed to OriginalHistory1 in the next frame.
251        let back = &mut self.prev_frame_history_buffer;
252        let mipmapper = cmd
253            .blitCommandEncoder()
254            .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
255        if back.texture.height() != input.height()
256            || back.texture.width() != input.width()
257            || input.pixelFormat() != back.texture.pixelFormat()
258        {
259            let size = Size {
260                width: input.width() as u32,
261                height: input.height() as u32,
262            };
263
264            let _old_back = std::mem::replace(
265                back,
266                OwnedTexture::new(&self.common.device, size, 1, input.pixelFormat())?,
267            );
268        }
269
270        back.copy_from(&mipmapper, input)?;
271        mipmapper.endEncoding();
272        Ok(())
273    }
274
275    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
276    /// to the caller. This function therefore requires no external synchronization of the device queue.
277    ///
278    /// ## Safety
279    /// The provided command buffer must be ready for recording.
280    /// The caller is responsible for ending the command buffer and immediately submitting it to a
281    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
282    pub fn load_from_preset_deferred(
283        preset: ShaderPreset,
284        queue: &ProtocolObject<dyn MTLCommandQueue>,
285        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
286        options: Option<&FilterChainOptionsMetal>,
287    ) -> error::Result<FilterChainMetal> {
288        let preset = ShaderPresetPack::load_from_preset::<FilterChainError>(preset)?;
289        Self::load_from_pack_deferred(preset, queue, cmd, options)
290    }
291
292    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
293    /// to the caller. This function therefore requires no external synchronization of the device queue.
294    ///
295    /// ## Safety
296    /// The provided command buffer must be ready for recording.
297    /// The caller is responsible for ending the command buffer and immediately submitting it to a
298    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
299    pub fn load_from_pack_deferred(
300        preset: ShaderPresetPack,
301        queue: &ProtocolObject<dyn MTLCommandQueue>,
302        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
303        options: Option<&FilterChainOptionsMetal>,
304    ) -> error::Result<FilterChainMetal> {
305        Self::load_from_pack_deferred_internal(preset, queue.device(), &cmd, options)
306    }
307
308    /// Load a filter chain from a pre-parsed `ShaderPreset`, deferring and GPU-side initialization
309    /// to the caller. This function therefore requires no external synchronization of the device queue.
310    ///
311    /// ## Safety
312    /// The provided command buffer must be ready for recording.
313    /// The caller is responsible for ending the command buffer and immediately submitting it to a
314    /// graphics queue. The command buffer must be completely executed before calling [`frame`](Self::frame).
315    fn load_from_pack_deferred_internal(
316        preset: ShaderPresetPack,
317        device: Retained<ProtocolObject<dyn MTLDevice>>,
318        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
319        options: Option<&FilterChainOptionsMetal>,
320    ) -> error::Result<FilterChainMetal> {
321        let config = RuntimeParameters::new(&preset);
322        let (passes, semantics) = compile_passes(preset.passes, &preset.textures)?;
323
324        let filters = Self::init_passes(&device, passes, &semantics)?;
325
326        let samplers = SamplerSet::new(&device)?;
327        let luts = FilterChainMetal::load_luts(&device, &cmd, preset.textures)?;
328        let framebuffer_gen = || {
329            Ok::<_, error::FilterChainError>(OwnedTexture::new(
330                &device,
331                Size::new(1, 1),
332                1,
333                ImageFormat::R8G8B8A8Unorm.into(),
334            )?)
335        };
336        let input_gen = || None;
337        let framebuffer_init = FramebufferInit::new(
338            filters.iter().map(|f| &f.reflection.meta),
339            &framebuffer_gen,
340            &input_gen,
341        );
342        let (output_framebuffers, output_textures) = framebuffer_init.init_output_framebuffers()?;
343        //
344        // initialize feedback framebuffers
345        let (feedback_framebuffers, feedback_textures) =
346            framebuffer_init.init_output_framebuffers()?;
347        //
348        // initialize history
349        let (history_framebuffers, history_textures) = framebuffer_init.init_history()?;
350
351        let history_buffer = framebuffer_gen()?;
352
353        let draw_quad = DrawQuad::new(&device)?;
354        Ok(FilterChainMetal {
355            draw_last_pass_feedback: framebuffer_init.uses_final_pass_as_feedback(),
356            common: FilterCommon {
357                luts,
358                samplers,
359                config,
360                draw_quad,
361                device,
362                output_textures,
363                feedback_textures,
364                history_textures,
365            },
366            passes: filters,
367            output_framebuffers,
368            feedback_framebuffers,
369            history_framebuffers,
370            prev_frame_history_buffer: history_buffer,
371            disable_mipmaps: options.map(|f| f.force_no_mipmaps).unwrap_or(false),
372            default_options: Default::default(),
373        })
374    }
375
376    /// Records shader rendering commands to the provided command encoder.
377    ///
378    /// SAFETY: The `MTLCommandBuffer` provided must not have an active encoder.
379    pub fn frame(
380        &mut self,
381        input: &ProtocolObject<dyn MTLTexture>,
382        viewport: &Viewport<MetalTextureRef>,
383        cmd: &ProtocolObject<dyn MTLCommandBuffer>,
384        frame_count: usize,
385        options: Option<&FrameOptionsMetal>,
386    ) -> error::Result<()> {
387        let max = std::cmp::min(self.passes.len(), self.common.config.passes_enabled());
388        if let Some(options) = &options {
389            let clear_desc = unsafe { MTLRenderPassDescriptor::new() };
390            if options.clear_history {
391                for (index, history) in self.history_framebuffers.iter().enumerate() {
392                    unsafe {
393                        let ca = clear_desc
394                            .colorAttachments()
395                            .objectAtIndexedSubscript(index);
396                        ca.setTexture(Some(&history.texture));
397                        ca.setLoadAction(MTLLoadAction::Clear);
398                        ca.setStoreAction(MTLStoreAction::Store);
399                    }
400                }
401
402                let clearpass = cmd
403                    .renderCommandEncoderWithDescriptor(&clear_desc)
404                    .ok_or(FilterChainError::FailedToCreateCommandBuffer)?;
405                clearpass.endEncoding();
406            }
407        }
408
409        self.push_history(&cmd, &input)?;
410
411        let passes = &mut self.passes[0..max];
412        if passes.is_empty() {
413            return Ok(());
414        }
415
416        let filter = passes[0].meta.filter;
417        let wrap_mode = passes[0].meta.wrap_mode;
418
419        // update history
420        for (texture, image) in self
421            .common
422            .history_textures
423            .iter_mut()
424            .zip(self.history_framebuffers.iter())
425        {
426            *texture = Some(image.as_input(filter, wrap_mode)?);
427        }
428
429        let original = InputTexture {
430            texture: input
431                .newTextureViewWithPixelFormat(input.pixelFormat())
432                .ok_or(FilterChainError::FailedToCreateTexture)?,
433            wrap_mode,
434            filter_mode: filter,
435            mip_filter: filter,
436        };
437
438        let mut source = original.try_clone()?;
439
440        source
441            .texture
442            .setLabel(Some(&*NSString::from_str("librashader_sourcetex")));
443
444        // swap output and feedback **before** recording command buffers
445        std::mem::swap(
446            &mut self.output_framebuffers,
447            &mut self.feedback_framebuffers,
448        );
449
450        // rescale render buffers to ensure all bindings are valid.
451        OwnedTexture::scale_framebuffers_with_context(
452            get_texture_size(&source.texture).into(),
453            get_texture_size(viewport.output),
454            get_texture_size(&original.texture).into(),
455            &mut self.output_framebuffers,
456            &mut self.feedback_framebuffers,
457            passes,
458            &self.common.device,
459            Some(&mut |index: usize,
460                       pass: &FilterPass,
461                       output: &OwnedTexture,
462                       feedback: &OwnedTexture| {
463                // refresh inputs
464                self.common.feedback_textures[index] =
465                    Some(feedback.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
466                self.common.output_textures[index] =
467                    Some(output.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
468                Ok(())
469            }),
470        )?;
471
472        let passes_len = passes.len();
473        let (pass, last) = passes.split_at_mut(passes_len - 1);
474        let options = options.unwrap_or(&self.default_options);
475
476        for (index, pass) in pass.iter_mut().enumerate() {
477            let target = &self.output_framebuffers[index];
478            source.filter_mode = pass.meta.filter;
479            source.wrap_mode = pass.meta.wrap_mode;
480            source.mip_filter = pass.meta.filter;
481
482            let out = RenderTarget::identity(target.texture.as_ref())?;
483            pass.draw(
484                &cmd,
485                index,
486                &self.common,
487                pass.meta.get_frame_count(frame_count),
488                options,
489                viewport,
490                &original,
491                &source,
492                &out,
493                QuadType::Offscreen,
494            )?;
495
496            if target.max_miplevels > 1 && !self.disable_mipmaps {
497                target.generate_mipmaps(&cmd)?;
498            }
499
500            self.common.output_textures[index] =
501                Some(target.as_input(pass.meta.filter, pass.meta.wrap_mode)?);
502            source = self.common.output_textures[index]
503                .as_ref()
504                .map(InputTexture::try_clone)
505                .unwrap()?;
506        }
507
508        // try to hint the optimizer
509        assert_eq!(last.len(), 1);
510
511        if let Some(pass) = last.iter_mut().next() {
512            if !pass
513                .graphics_pipeline
514                .has_format(viewport.output.pixelFormat())
515            {
516                // need to recompile
517                pass.graphics_pipeline
518                    .recompile(&self.common.device, viewport.output.pixelFormat())?;
519            }
520
521            source.filter_mode = pass.meta.filter;
522            source.wrap_mode = pass.meta.wrap_mode;
523            source.mip_filter = pass.meta.filter;
524            let index = passes_len - 1;
525
526            if self.draw_last_pass_feedback {
527                let output_image = &self.output_framebuffers[index].texture;
528                let out = RenderTarget::viewport_with_output(output_image.as_ref(), viewport);
529                pass.draw(
530                    &cmd,
531                    passes_len - 1,
532                    &self.common,
533                    pass.meta.get_frame_count(frame_count),
534                    options,
535                    viewport,
536                    &original,
537                    &source,
538                    &out,
539                    QuadType::Final,
540                )?;
541            }
542
543            let out = RenderTarget::viewport(viewport);
544            pass.draw(
545                &cmd,
546                index,
547                &self.common,
548                pass.meta.get_frame_count(frame_count),
549                options,
550                viewport,
551                &original,
552                &source,
553                &out,
554                QuadType::Final,
555            )?;
556        }
557
558        Ok(())
559    }
560}