Skip to main content

apple_mps/
filters.rs

1use crate::ffi;
2use crate::image::{Image, ImageRegion};
3use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Declares an owning wrapper type around an opaque, retained MPS object pointer.
///
/// `opaque_handle!(Name, "doc")` expands to:
/// - `pub struct Name` with a single private `ptr: *mut c_void` field, documented with `doc`;
/// - `unsafe impl Send` / `unsafe impl Sync` for the wrapper;
/// - a `Drop` impl that passes a non-null `ptr` to `ffi::mps_object_release`;
/// - a `Debug` impl printing the type name and the raw pointer value;
/// - an `as_ptr` accessor that exposes the raw pointer without giving up ownership.
///
/// The invoking module must have `ffi`, `ptr` (`core::ptr`) and `c_void` in scope.
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: MPS filter handles are opaque pointers to thread-safe Swift/ObjC objects.
        unsafe impl Send for $name {}
        // SAFETY: MPS filter handles are opaque pointers to thread-safe Swift/ObjC objects.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                    unsafe { ffi::mps_object_release(self.ptr) };
                    // Clear the field so a released pointer is never left behind in the wrapper.
                    self.ptr = ptr::null_mut();
                }
            }
        }

        // The wrapped object is opaque, so only the pointer value itself is shown.
        impl core::fmt::Debug for $name {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                f.debug_struct(stringify!($name))
                    .field("ptr", &self.ptr)
                    .finish()
            }
        }

        impl $name {
            /// Returns the retained Objective-C pointer backing this wrapper.
            ///
            /// Ownership stays with the wrapper; the pointer is released when the wrapper drops.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
38
/// Generates the shared method set for single-source (unary) image kernels.
///
/// `$name` must be a wrapper type with a `ptr` field, such as one declared through
/// `opaque_handle!`. Each generated method forwards to the matching `ffi::mps_unary_*` shim.
macro_rules! impl_unary_methods {
    ($name:ident) => {
        impl $name {
            /// Records the kernel into `command_buffer`, reading from `source` and writing to
            /// `destination`, both supplied as `MPSImage` wrappers.
            pub fn encode_image(
                &self,
                command_buffer: &CommandBuffer,
                source: &Image,
                destination: &Image,
            ) {
                let kernel = self.ptr;
                let (cmd, src, dst) =
                    (command_buffer.as_ptr(), source.as_ptr(), destination.as_ptr());
                // SAFETY: Every pointer comes from a safe wrapper that is borrowed, and thus
                // alive, for the whole call.
                unsafe { ffi::mps_unary_encode_image(kernel, cmd, src, dst) };
            }

            /// Records the kernel into `command_buffer`, reading from `source` and writing to
            /// `destination`, both supplied as raw `MTLTexture` wrappers.
            pub fn encode_texture(
                &self,
                command_buffer: &CommandBuffer,
                source: &MetalTexture,
                destination: &MetalTexture,
            ) {
                let kernel = self.ptr;
                let (cmd, src, dst) =
                    (command_buffer.as_ptr(), source.as_ptr(), destination.as_ptr());
                // SAFETY: Every pointer comes from a safe wrapper that is borrowed, and thus
                // alive, for the whole call.
                unsafe { ffi::mps_unary_encode_texture(kernel, cmd, src, dst) };
            }

            /// Sets the kernel's edge mode.
            ///
            /// `edge_mode` is handed to the native setter unchanged.
            pub fn set_edge_mode(&self, edge_mode: usize) {
                let kernel = self.ptr;
                // SAFETY: `kernel` is the pointer owned by `self`, which outlives the call.
                unsafe { ffi::mps_unary_set_edge_mode(kernel, edge_mode) };
            }

            /// Limits destination writes to the clip rectangle described by `region`.
            ///
            /// The origin (`x`, `y`, `z`) and size (`width`, `height`, `depth`) are forwarded
            /// in that order.
            pub fn set_clip_rect(&self, region: ImageRegion) {
                let kernel = self.ptr;
                let origin = (region.x, region.y, region.z);
                let size = (region.width, region.height, region.depth);
                // SAFETY: `kernel` is the pointer owned by `self`, which outlives the call.
                unsafe {
                    ffi::mps_unary_set_clip_rect(
                        kernel, origin.0, origin.1, origin.2, size.0, size.1, size.2,
                    )
                };
            }
        }
    };
}
102
/// Generates the shared method set for two-source (binary) image kernels.
///
/// `$name` must be a wrapper type with a `ptr` field, such as one declared through
/// `opaque_handle!`. Each generated method forwards to the matching `ffi::mps_binary_*` shim.
macro_rules! impl_binary_methods {
    ($name:ident) => {
        impl $name {
            /// Records the kernel into `command_buffer`, combining `primary` and `secondary`
            /// into `destination`; all three are `MPSImage` wrappers.
            pub fn encode_image(
                &self,
                command_buffer: &CommandBuffer,
                primary: &Image,
                secondary: &Image,
                destination: &Image,
            ) {
                let kernel = self.ptr;
                let cmd = command_buffer.as_ptr();
                let (lhs, rhs, dst) = (primary.as_ptr(), secondary.as_ptr(), destination.as_ptr());
                // SAFETY: Every pointer comes from a safe wrapper that is borrowed, and thus
                // alive, for the whole call.
                unsafe { ffi::mps_binary_encode_image(kernel, cmd, lhs, rhs, dst) };
            }

            /// Records the kernel into `command_buffer`, combining `primary` and `secondary`
            /// into `destination`; all three are raw `MTLTexture` wrappers.
            pub fn encode_texture(
                &self,
                command_buffer: &CommandBuffer,
                primary: &MetalTexture,
                secondary: &MetalTexture,
                destination: &MetalTexture,
            ) {
                let kernel = self.ptr;
                let cmd = command_buffer.as_ptr();
                let (lhs, rhs, dst) = (primary.as_ptr(), secondary.as_ptr(), destination.as_ptr());
                // SAFETY: Every pointer comes from a safe wrapper that is borrowed, and thus
                // alive, for the whole call.
                unsafe { ffi::mps_binary_encode_texture(kernel, cmd, lhs, rhs, dst) };
            }

            /// Sets the edge mode applied to the primary input.
            ///
            /// `edge_mode` is handed to the native setter unchanged.
            pub fn set_primary_edge_mode(&self, edge_mode: usize) {
                let kernel = self.ptr;
                // SAFETY: `kernel` is the pointer owned by `self`, which outlives the call.
                unsafe { ffi::mps_binary_set_primary_edge_mode(kernel, edge_mode) };
            }

            /// Sets the edge mode applied to the secondary input.
            ///
            /// `edge_mode` is handed to the native setter unchanged.
            pub fn set_secondary_edge_mode(&self, edge_mode: usize) {
                let kernel = self.ptr;
                // SAFETY: `kernel` is the pointer owned by `self`, which outlives the call.
                unsafe { ffi::mps_binary_set_secondary_edge_mode(kernel, edge_mode) };
            }

            /// Limits destination writes to the clip rectangle described by `region`.
            ///
            /// The origin (`x`, `y`, `z`) and size (`width`, `height`, `depth`) are forwarded
            /// in that order.
            pub fn set_clip_rect(&self, region: ImageRegion) {
                let kernel = self.ptr;
                let origin = (region.x, region.y, region.z);
                let size = (region.width, region.height, region.depth);
                // SAFETY: `kernel` is the pointer owned by `self`, which outlives the call.
                unsafe {
                    ffi::mps_binary_set_clip_rect(
                        kernel, origin.0, origin.1, origin.2, size.0, size.1, size.2,
                    )
                };
            }
        }
    };
}
176
/// `MPSScaleTransform` values used by resampling kernels.
///
/// Each field is the snake_case spelling of the Objective-C struct member named in its
/// documentation. Values can be compared with `==`, which follows `f64` comparison rules
/// (so a transform containing `NaN` is not equal to itself).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScaleTransform {
    /// Horizontal scale factor (`scaleX` on `MPSScaleTransform`).
    pub scale_x: f64,
    /// Vertical scale factor (`scaleY` on `MPSScaleTransform`).
    pub scale_y: f64,
    /// Horizontal translation (`translateX` on `MPSScaleTransform`).
    pub translate_x: f64,
    /// Vertical translation (`translateY` on `MPSScaleTransform`).
    pub translate_y: f64,
}
189
/// Plain-Rust configuration for `MPSImageHistogramInfo`.
///
/// Each field is the Rust-side counterpart of the Objective-C struct member named in its
/// documentation. Values can be compared with `==`, which follows `f32` comparison rules
/// for the pixel ranges.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HistogramInfo {
    /// Number of histogram entries (bins) per channel
    /// (`numberOfHistogramEntries` on `MPSImageHistogramInfo`).
    pub number_of_entries: usize,
    /// Whether a histogram is also computed for the alpha channel
    /// (`histogramForAlpha` on `MPSImageHistogramInfo`).
    pub histogram_for_alpha: bool,
    /// Per-channel minimum pixel value (`minPixelValue` on `MPSImageHistogramInfo`).
    pub min_pixel_value: [f32; 4],
    /// Per-channel maximum pixel value (`maxPixelValue` on `MPSImageHistogramInfo`).
    pub max_pixel_value: [f32; 4],
}
202
203opaque_handle!(ImageGaussianBlur, "Wraps `MPSImageGaussianBlur`.");
204impl ImageGaussianBlur {
205    /// Wraps a constructor on `MPSImageGaussianBlur`.
206    #[must_use]
207    pub fn new(device: &MetalDevice, sigma: f32) -> Option<Self> {
208        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
209        let ptr = unsafe { ffi::mps_image_gaussian_blur_new(device.as_ptr(), sigma) };
210        if ptr.is_null() {
211            None
212        } else {
213            Some(Self { ptr })
214        }
215    }
216}
217impl_unary_methods!(ImageGaussianBlur);
218
219opaque_handle!(ImageBox, "Wraps `MPSImageBox`.");
220impl ImageBox {
221    /// Wraps a constructor on `MPSImageBox`.
222    #[must_use]
223    pub fn new(device: &MetalDevice, kernel_width: usize, kernel_height: usize) -> Option<Self> {
224        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
225        let ptr = unsafe { ffi::mps_image_box_new(device.as_ptr(), kernel_width, kernel_height) };
226        if ptr.is_null() {
227            None
228        } else {
229            Some(Self { ptr })
230        }
231    }
232}
233impl_unary_methods!(ImageBox);
234
235opaque_handle!(ImageSobel, "Wraps `MPSImageSobel`.");
236impl ImageSobel {
237    /// Wraps a constructor on `MPSImageSobel`.
238    #[must_use]
239    pub fn new(device: &MetalDevice) -> Option<Self> {
240        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
241        let ptr = unsafe { ffi::mps_image_sobel_new(device.as_ptr(), core::ptr::null()) };
242        if ptr.is_null() {
243            None
244        } else {
245            Some(Self { ptr })
246        }
247    }
248
249    /// Wraps the corresponding `MPSImageSobel` method.
250    #[must_use]
251    pub fn with_transform(device: &MetalDevice, transform: [f32; 3]) -> Option<Self> {
252        // SAFETY: `transform` lives for the duration of the FFI call.
253        let ptr = unsafe { ffi::mps_image_sobel_new(device.as_ptr(), transform.as_ptr()) };
254        if ptr.is_null() {
255            None
256        } else {
257            Some(Self { ptr })
258        }
259    }
260}
261impl_unary_methods!(ImageSobel);
262
263opaque_handle!(ImageMedian, "Wraps `MPSImageMedian`.");
264impl ImageMedian {
265    /// Wraps a constructor on `MPSImageMedian`.
266    #[must_use]
267    pub fn new(device: &MetalDevice, kernel_diameter: usize) -> Option<Self> {
268        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
269        let ptr = unsafe { ffi::mps_image_median_new(device.as_ptr(), kernel_diameter) };
270        if ptr.is_null() {
271            None
272        } else {
273            Some(Self { ptr })
274        }
275    }
276}
277impl_unary_methods!(ImageMedian);
278
279opaque_handle!(ImageConvolution, "Wraps `MPSImageConvolution`.");
280impl ImageConvolution {
281    /// Wraps a constructor on `MPSImageConvolution`.
282    #[must_use]
283    pub fn new(
284        device: &MetalDevice,
285        kernel_width: usize,
286        kernel_height: usize,
287        weights: &[f32],
288    ) -> Option<Self> {
289        if weights.len() != kernel_width.saturating_mul(kernel_height) {
290            return None;
291        }
292
293        // SAFETY: `weights` lives for the duration of the FFI call.
294        let ptr = unsafe {
295            ffi::mps_image_convolution_new(
296                device.as_ptr(),
297                kernel_width,
298                kernel_height,
299                weights.as_ptr(),
300            )
301        };
302        if ptr.is_null() {
303            None
304        } else {
305            Some(Self { ptr })
306        }
307    }
308}
309impl_unary_methods!(ImageConvolution);
310
311opaque_handle!(ImageBilinearScale, "Wraps `MPSImageBilinearScale`.");
312impl ImageBilinearScale {
313    /// Wraps a constructor on `MPSImageBilinearScale`.
314    #[must_use]
315    pub fn new(device: &MetalDevice) -> Option<Self> {
316        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
317        let ptr = unsafe { ffi::mps_image_bilinear_scale_new(device.as_ptr()) };
318        if ptr.is_null() {
319            None
320        } else {
321            Some(Self { ptr })
322        }
323    }
324
325    /// Override the default fit-to-destination scale transform.
326    pub fn set_scale_transform(&self, transform: ScaleTransform) {
327        // SAFETY: The kernel pointer is valid for the duration of the call.
328        unsafe {
329            ffi::mps_image_scale_set_transform(
330                self.ptr,
331                transform.scale_x,
332                transform.scale_y,
333                transform.translate_x,
334                transform.translate_y,
335            );
336        };
337    }
338}
339impl_unary_methods!(ImageBilinearScale);
340
341opaque_handle!(ImageLanczosScale, "Wraps `MPSImageLanczosScale`.");
342impl ImageLanczosScale {
343    /// Wraps a constructor on `MPSImageLanczosScale`.
344    #[must_use]
345    pub fn new(device: &MetalDevice) -> Option<Self> {
346        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
347        let ptr = unsafe { ffi::mps_image_lanczos_scale_new(device.as_ptr()) };
348        if ptr.is_null() {
349            None
350        } else {
351            Some(Self { ptr })
352        }
353    }
354
355    /// Override the default fit-to-destination scale transform.
356    pub fn set_scale_transform(&self, transform: ScaleTransform) {
357        // SAFETY: The kernel pointer is valid for the duration of the call.
358        unsafe {
359            ffi::mps_image_scale_set_transform(
360                self.ptr,
361                transform.scale_x,
362                transform.scale_y,
363                transform.translate_x,
364                transform.translate_y,
365            );
366        };
367    }
368}
369impl_unary_methods!(ImageLanczosScale);
370
371opaque_handle!(ImageThresholdBinary, "Wraps `MPSImageThresholdBinary`.");
372impl ImageThresholdBinary {
373    /// Wraps a constructor on `MPSImageThresholdBinary`.
374    #[must_use]
375    pub fn new(device: &MetalDevice, threshold_value: f32, maximum_value: f32) -> Option<Self> {
376        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
377        let ptr = unsafe {
378            ffi::mps_image_threshold_binary_new(
379                device.as_ptr(),
380                threshold_value,
381                maximum_value,
382                core::ptr::null(),
383            )
384        };
385        if ptr.is_null() {
386            None
387        } else {
388            Some(Self { ptr })
389        }
390    }
391
392    /// Wraps the corresponding `MPSImageThresholdBinary` method.
393    #[must_use]
394    pub fn with_transform(
395        device: &MetalDevice,
396        threshold_value: f32,
397        maximum_value: f32,
398        transform: [f32; 3],
399    ) -> Option<Self> {
400        // SAFETY: `transform` lives for the duration of the FFI call.
401        let ptr = unsafe {
402            ffi::mps_image_threshold_binary_new(
403                device.as_ptr(),
404                threshold_value,
405                maximum_value,
406                transform.as_ptr(),
407            )
408        };
409        if ptr.is_null() {
410            None
411        } else {
412            Some(Self { ptr })
413        }
414    }
415}
416impl_unary_methods!(ImageThresholdBinary);
417
418opaque_handle!(ImageHistogram, "Wraps `MPSImageHistogram`.");
419impl ImageHistogram {
420    /// Wraps a constructor on `MPSImageHistogram`.
421    #[must_use]
422    pub fn new(device: &MetalDevice, info: HistogramInfo) -> Option<Self> {
423        // SAFETY: `info` arrays live for the duration of the FFI call.
424        let ptr = unsafe {
425            ffi::mps_image_histogram_new(
426                device.as_ptr(),
427                info.number_of_entries,
428                info.histogram_for_alpha,
429                info.min_pixel_value.as_ptr(),
430                info.max_pixel_value.as_ptr(),
431            )
432        };
433        if ptr.is_null() {
434            None
435        } else {
436            Some(Self { ptr })
437        }
438    }
439
440    /// Encode a histogram pass using an `MPSImage` source.
441    pub fn encode_image(
442        &self,
443        command_buffer: &CommandBuffer,
444        source: &Image,
445        histogram_buffer: &MetalBuffer,
446        histogram_offset: usize,
447    ) {
448        // SAFETY: All handles come from safe wrappers and remain alive for the call.
449        unsafe {
450            ffi::mps_image_histogram_encode_image(
451                self.ptr,
452                command_buffer.as_ptr(),
453                source.as_ptr(),
454                histogram_buffer.as_ptr(),
455                histogram_offset,
456            );
457        };
458    }
459
460    /// Encode a histogram pass using a raw `MTLTexture` source.
461    pub fn encode_texture(
462        &self,
463        command_buffer: &CommandBuffer,
464        source: &MetalTexture,
465        histogram_buffer: &MetalBuffer,
466        histogram_offset: usize,
467    ) {
468        // SAFETY: All handles come from safe wrappers and remain alive for the call.
469        unsafe {
470            ffi::mps_image_histogram_encode_texture(
471                self.ptr,
472                command_buffer.as_ptr(),
473                source.as_ptr(),
474                histogram_buffer.as_ptr(),
475                histogram_offset,
476            );
477        };
478    }
479
480    /// Report the minimum output buffer size for the given source pixel format.
481    #[must_use]
482    pub fn histogram_size_for_source_format(&self, source_format: usize) -> usize {
483        // SAFETY: The histogram pointer is valid for the duration of the call.
484        unsafe { ffi::mps_image_histogram_size_for_source_format(self.ptr, source_format) }
485    }
486}
487
488opaque_handle!(ImageStatisticsMinAndMax, "Wraps `MPSImageStatisticsMinAndMax`.");
489impl ImageStatisticsMinAndMax {
490    /// Wraps a constructor on `MPSImageStatisticsMinAndMax`.
491    #[must_use]
492    pub fn new(device: &MetalDevice) -> Option<Self> {
493        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
494        let ptr = unsafe { ffi::mps_image_statistics_min_max_new(device.as_ptr()) };
495        if ptr.is_null() {
496            None
497        } else {
498            Some(Self { ptr })
499        }
500    }
501}
502impl_unary_methods!(ImageStatisticsMinAndMax);
503
504opaque_handle!(ImageStatisticsMean, "Wraps `MPSImageStatisticsMean`.");
505impl ImageStatisticsMean {
506    /// Wraps a constructor on `MPSImageStatisticsMean`.
507    #[must_use]
508    pub fn new(device: &MetalDevice) -> Option<Self> {
509        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
510        let ptr = unsafe { ffi::mps_image_statistics_mean_new(device.as_ptr()) };
511        if ptr.is_null() {
512            None
513        } else {
514            Some(Self { ptr })
515        }
516    }
517}
518impl_unary_methods!(ImageStatisticsMean);
519
520opaque_handle!(ImageReduceRowMin, "Wraps `MPSImageReduceRowMin`.");
521impl ImageReduceRowMin {
522    /// Wraps a constructor on `MPSImageReduceRowMin`.
523    #[must_use]
524    pub fn new(device: &MetalDevice) -> Option<Self> {
525        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
526        let ptr = unsafe { ffi::mps_image_reduce_row_min_new(device.as_ptr()) };
527        if ptr.is_null() {
528            None
529        } else {
530            Some(Self { ptr })
531        }
532    }
533}
534impl_unary_methods!(ImageReduceRowMin);
535
536opaque_handle!(ImageReduceRowMax, "Wraps `MPSImageReduceRowMax`.");
537impl ImageReduceRowMax {
538    /// Wraps a constructor on `MPSImageReduceRowMax`.
539    #[must_use]
540    pub fn new(device: &MetalDevice) -> Option<Self> {
541        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
542        let ptr = unsafe { ffi::mps_image_reduce_row_max_new(device.as_ptr()) };
543        if ptr.is_null() {
544            None
545        } else {
546            Some(Self { ptr })
547        }
548    }
549}
550impl_unary_methods!(ImageReduceRowMax);
551
552opaque_handle!(ImageReduceRowMean, "Wraps `MPSImageReduceRowMean`.");
553impl ImageReduceRowMean {
554    /// Wraps a constructor on `MPSImageReduceRowMean`.
555    #[must_use]
556    pub fn new(device: &MetalDevice) -> Option<Self> {
557        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
558        let ptr = unsafe { ffi::mps_image_reduce_row_mean_new(device.as_ptr()) };
559        if ptr.is_null() {
560            None
561        } else {
562            Some(Self { ptr })
563        }
564    }
565}
566impl_unary_methods!(ImageReduceRowMean);
567
568opaque_handle!(ImageReduceRowSum, "Wraps `MPSImageReduceRowSum`.");
569impl ImageReduceRowSum {
570    /// Wraps a constructor on `MPSImageReduceRowSum`.
571    #[must_use]
572    pub fn new(device: &MetalDevice) -> Option<Self> {
573        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
574        let ptr = unsafe { ffi::mps_image_reduce_row_sum_new(device.as_ptr()) };
575        if ptr.is_null() {
576            None
577        } else {
578            Some(Self { ptr })
579        }
580    }
581}
582impl_unary_methods!(ImageReduceRowSum);
583
584opaque_handle!(ImageAdd, "Wraps `MPSImageAdd`.");
585impl ImageAdd {
586    /// Wraps a constructor on `MPSImageAdd`.
587    #[must_use]
588    pub fn new(device: &MetalDevice) -> Option<Self> {
589        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
590        let ptr = unsafe { ffi::mps_image_add_new(device.as_ptr()) };
591        if ptr.is_null() {
592            None
593        } else {
594            Some(Self { ptr })
595        }
596    }
597
598    /// Set `primaryScale`, `secondaryScale`, and `bias` in one call.
599    pub fn set_scales(&self, primary_scale: f32, secondary_scale: f32, bias: f32) {
600        // SAFETY: The kernel pointer is valid for the duration of the call.
601        unsafe {
602            ffi::mps_image_arithmetic_set_scales_bias(
603                self.ptr,
604                primary_scale,
605                secondary_scale,
606                bias,
607            );
608        };
609    }
610
611    /// Clamp arithmetic results to the closed interval `[minimum_value, maximum_value]`.
612    pub fn set_clamp(&self, minimum_value: f32, maximum_value: f32) {
613        // SAFETY: The kernel pointer is valid for the duration of the call.
614        unsafe { ffi::mps_image_arithmetic_set_clamp(self.ptr, minimum_value, maximum_value) };
615    }
616}
617impl_binary_methods!(ImageAdd);
618
619/// Convenience wrapper for `scale-and-add` semantics implemented with `MPSImageAdd`.
620pub struct ImageScaleAndAdd {
621    inner: ImageAdd,
622}
623
624impl ImageScaleAndAdd {
625    /// Build an image add kernel with non-unit primary/secondary scales.
626    #[must_use]
627    pub fn new(
628        device: &MetalDevice,
629        primary_scale: f32,
630        secondary_scale: f32,
631        bias: f32,
632    ) -> Option<Self> {
633        let inner = ImageAdd::new(device)?;
634        inner.set_scales(primary_scale, secondary_scale, bias);
635        Some(Self { inner })
636    }
637
638    /// Wraps a Metal Performance Shaders raw value.
639    #[must_use]
640    pub const fn as_ptr(&self) -> *mut c_void {
641        self.inner.as_ptr()
642    }
643
644    /// Wraps the corresponding `MPSImageAdd` encode entry point.
645    pub fn encode_image(
646        &self,
647        command_buffer: &CommandBuffer,
648        primary: &Image,
649        secondary: &Image,
650        destination: &Image,
651    ) {
652        self.inner
653            .encode_image(command_buffer, primary, secondary, destination);
654    }
655
656    /// Wraps the corresponding `MPSImageAdd` encode entry point.
657    pub fn encode_texture(
658        &self,
659        command_buffer: &CommandBuffer,
660        primary: &MetalTexture,
661        secondary: &MetalTexture,
662        destination: &MetalTexture,
663    ) {
664        self.inner
665            .encode_texture(command_buffer, primary, secondary, destination);
666    }
667
668    /// Wraps the corresponding `MPSImageAdd` setter.
669    pub fn set_primary_edge_mode(&self, edge_mode: usize) {
670        self.inner.set_primary_edge_mode(edge_mode);
671    }
672
673    /// Wraps the corresponding `MPSImageAdd` setter.
674    pub fn set_secondary_edge_mode(&self, edge_mode: usize) {
675        self.inner.set_secondary_edge_mode(edge_mode);
676    }
677
678    /// Wraps the corresponding `MPSImageAdd` setter.
679    pub fn set_clip_rect(&self, region: ImageRegion) {
680        self.inner.set_clip_rect(region);
681    }
682
683    /// Wraps the corresponding `MPSImageAdd` setter.
684    pub fn set_clamp(&self, minimum_value: f32, maximum_value: f32) {
685        self.inner.set_clamp(minimum_value, maximum_value);
686    }
687}