Skip to main content

apple_mps/
filters.rs

1use crate::ffi;
2use crate::image::{Image, ImageRegion};
3use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
// Define a public, owning wrapper around a +1 retained Swift/Objective-C object pointer.
//
// The generated type stores one raw pointer, releases it through `ffi::mps_object_release`
// when dropped (null pointers are skipped), derives `Debug`, and exposes the raw pointer via
// `as_ptr`. Outer attributes, including `///` doc comments, written before the type name are
// forwarded to the generated struct, e.g. `opaque_handle!(/// Docs. MyHandle);`. The plain
// `opaque_handle!(MyHandle);` form keeps working unchanged.
macro_rules! opaque_handle {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Debug)]
        pub struct $name {
            /// Owned +1 retained object pointer; `Drop` tolerates a null value.
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper exclusively owns its retained pointer and only releases it in
        // `Drop`, so moving it to another thread transfers that ownership (assumes
        // `ffi::mps_object_release` may be called from any thread).
        // NOTE(review): Apple documents `MPSKernel` objects as not safe for simultaneous use
        // from multiple threads, and the setters generated for these handles take `&self`;
        // `Sync` therefore allows two threads to mutate one kernel concurrently — confirm
        // this is intended before relying on it.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
                    unsafe { ffi::mps_object_release(self.ptr) };
                    // Clear the field so the released pointer can never be observed again.
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl $name {
            /// Return the underlying object pointer without transferring ownership.
            ///
            /// The pointer is only valid while `self` is alive, and callers must not release it.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
34
// Generate the method set shared by single-input (`MPSUnaryImageKernel`-style) filters.
//
// Every generated method forwards to an `ffi::mps_unary_*` shim, passing the wrapper's own
// `ptr` as the kernel object.
macro_rules! impl_unary_methods {
    ($kernel:ident) => {
        impl $kernel {
            /// Configure the kernel's edge mode.
            ///
            /// NOTE(review): `edge_mode` is forwarded unchecked; presumably it is a raw
            /// `MPSImageEdgeMode` value — confirm against the FFI shim.
            pub fn set_edge_mode(&self, edge_mode: usize) {
                // SAFETY: `self.ptr` is the kernel owned by this wrapper and outlives the call.
                unsafe { ffi::mps_unary_set_edge_mode(self.ptr, edge_mode) };
            }

            /// Restrict writes to a destination clip rectangle.
            pub fn set_clip_rect(&self, region: ImageRegion) {
                let ImageRegion {
                    x,
                    y,
                    z,
                    width,
                    height,
                    depth,
                    ..
                } = region;
                // SAFETY: `self.ptr` is the kernel owned by this wrapper and outlives the call.
                unsafe { ffi::mps_unary_set_clip_rect(self.ptr, x, y, z, width, height, depth) };
            }

            /// Encode the filter against `MPSImage` inputs/outputs.
            pub fn encode_image(
                &self,
                command_buffer: &CommandBuffer,
                source: &Image,
                destination: &Image,
            ) {
                // SAFETY: each pointer is borrowed from a live safe wrapper, so every object
                // stays alive until the call returns.
                unsafe {
                    ffi::mps_unary_encode_image(
                        self.ptr,
                        command_buffer.as_ptr(),
                        source.as_ptr(),
                        destination.as_ptr(),
                    );
                }
            }

            /// Encode the filter directly against `MTLTexture` inputs/outputs.
            pub fn encode_texture(
                &self,
                command_buffer: &CommandBuffer,
                source: &MetalTexture,
                destination: &MetalTexture,
            ) {
                // SAFETY: each pointer is borrowed from a live safe wrapper, so every object
                // stays alive until the call returns.
                unsafe {
                    ffi::mps_unary_encode_texture(
                        self.ptr,
                        command_buffer.as_ptr(),
                        source.as_ptr(),
                        destination.as_ptr(),
                    );
                }
            }
        }
    };
}
98
// Generate the method set shared by two-input (`MPSBinaryImageKernel`-style) filters.
//
// Every generated method forwards to an `ffi::mps_binary_*` shim, passing the wrapper's own
// `ptr` as the kernel object.
macro_rules! impl_binary_methods {
    ($kernel:ident) => {
        impl $kernel {
            /// Configure the primary input edge mode.
            ///
            /// NOTE(review): the value is forwarded unchecked; presumably a raw
            /// `MPSImageEdgeMode` — confirm against the FFI shim.
            pub fn set_primary_edge_mode(&self, edge_mode: usize) {
                let kernel = self.ptr;
                // SAFETY: `kernel` is owned by `self`, which outlives the call.
                unsafe { ffi::mps_binary_set_primary_edge_mode(kernel, edge_mode) };
            }

            /// Configure the secondary input edge mode.
            ///
            /// NOTE(review): the value is forwarded unchecked; presumably a raw
            /// `MPSImageEdgeMode` — confirm against the FFI shim.
            pub fn set_secondary_edge_mode(&self, edge_mode: usize) {
                let kernel = self.ptr;
                // SAFETY: `kernel` is owned by `self`, which outlives the call.
                unsafe { ffi::mps_binary_set_secondary_edge_mode(kernel, edge_mode) };
            }

            /// Restrict writes to a destination clip rectangle.
            pub fn set_clip_rect(&self, region: ImageRegion) {
                let (x, y, z) = (region.x, region.y, region.z);
                let (width, height, depth) = (region.width, region.height, region.depth);
                let kernel = self.ptr;
                // SAFETY: `kernel` is owned by `self`, which outlives the call.
                unsafe { ffi::mps_binary_set_clip_rect(kernel, x, y, z, width, height, depth) };
            }

            /// Encode the filter against `MPSImage` inputs/outputs.
            pub fn encode_image(
                &self,
                command_buffer: &CommandBuffer,
                primary: &Image,
                secondary: &Image,
                destination: &Image,
            ) {
                // SAFETY: each pointer is borrowed from a live safe wrapper, so every object
                // stays alive until the call returns.
                unsafe {
                    ffi::mps_binary_encode_image(
                        self.ptr,
                        command_buffer.as_ptr(),
                        primary.as_ptr(),
                        secondary.as_ptr(),
                        destination.as_ptr(),
                    );
                }
            }

            /// Encode the filter directly against `MTLTexture` inputs/outputs.
            pub fn encode_texture(
                &self,
                command_buffer: &CommandBuffer,
                primary: &MetalTexture,
                secondary: &MetalTexture,
                destination: &MetalTexture,
            ) {
                // SAFETY: each pointer is borrowed from a live safe wrapper, so every object
                // stays alive until the call returns.
                unsafe {
                    ffi::mps_binary_encode_texture(
                        self.ptr,
                        command_buffer.as_ptr(),
                        primary.as_ptr(),
                        secondary.as_ptr(),
                        destination.as_ptr(),
                    );
                }
            }
        }
    };
}
172
/// `MPSScaleTransform` values used by resampling kernels.
///
/// The four components are forwarded unchanged to the kernel by the scaling filters'
/// `set_scale_transform` methods.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScaleTransform {
    /// Horizontal scale factor.
    pub scale_x: f64,
    /// Vertical scale factor.
    pub scale_y: f64,
    /// Horizontal translation.
    pub translate_x: f64,
    /// Vertical translation.
    pub translate_y: f64,
}

impl ScaleTransform {
    /// The identity transform: unit scale on both axes and no translation.
    pub const IDENTITY: Self = Self {
        scale_x: 1.0,
        scale_y: 1.0,
        translate_x: 0.0,
        translate_y: 0.0,
    };
}
181
/// Plain-Rust configuration for `MPSImageHistogramInfo`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HistogramInfo {
    /// Number of histogram bins per channel (`numberOfHistogramEntries`).
    ///
    /// Apple documents this as required to be a power of two, and at least 256.
    pub number_of_entries: usize,
    /// Whether a histogram is also computed for the alpha channel (`histogramForAlpha`).
    pub histogram_for_alpha: bool,
    /// Per-channel lower bound of the histogrammed value range (`minPixelValue`).
    pub min_pixel_value: [f32; 4],
    /// Per-channel upper bound of the histogrammed value range (`maxPixelValue`).
    pub max_pixel_value: [f32; 4],
}
190
191opaque_handle!(ImageGaussianBlur);
192impl ImageGaussianBlur {
193    #[must_use]
194    pub fn new(device: &MetalDevice, sigma: f32) -> Option<Self> {
195        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
196        let ptr = unsafe { ffi::mps_image_gaussian_blur_new(device.as_ptr(), sigma) };
197        if ptr.is_null() {
198            None
199        } else {
200            Some(Self { ptr })
201        }
202    }
203}
204impl_unary_methods!(ImageGaussianBlur);
205
206opaque_handle!(ImageBox);
207impl ImageBox {
208    #[must_use]
209    pub fn new(device: &MetalDevice, kernel_width: usize, kernel_height: usize) -> Option<Self> {
210        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
211        let ptr = unsafe { ffi::mps_image_box_new(device.as_ptr(), kernel_width, kernel_height) };
212        if ptr.is_null() {
213            None
214        } else {
215            Some(Self { ptr })
216        }
217    }
218}
219impl_unary_methods!(ImageBox);
220
221opaque_handle!(ImageSobel);
222impl ImageSobel {
223    #[must_use]
224    pub fn new(device: &MetalDevice) -> Option<Self> {
225        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
226        let ptr = unsafe { ffi::mps_image_sobel_new(device.as_ptr(), core::ptr::null()) };
227        if ptr.is_null() {
228            None
229        } else {
230            Some(Self { ptr })
231        }
232    }
233
234    #[must_use]
235    pub fn with_transform(device: &MetalDevice, transform: [f32; 3]) -> Option<Self> {
236        // SAFETY: `transform` lives for the duration of the FFI call.
237        let ptr = unsafe { ffi::mps_image_sobel_new(device.as_ptr(), transform.as_ptr()) };
238        if ptr.is_null() {
239            None
240        } else {
241            Some(Self { ptr })
242        }
243    }
244}
245impl_unary_methods!(ImageSobel);
246
247opaque_handle!(ImageMedian);
248impl ImageMedian {
249    #[must_use]
250    pub fn new(device: &MetalDevice, kernel_diameter: usize) -> Option<Self> {
251        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
252        let ptr = unsafe { ffi::mps_image_median_new(device.as_ptr(), kernel_diameter) };
253        if ptr.is_null() {
254            None
255        } else {
256            Some(Self { ptr })
257        }
258    }
259}
260impl_unary_methods!(ImageMedian);
261
262opaque_handle!(ImageConvolution);
263impl ImageConvolution {
264    #[must_use]
265    pub fn new(
266        device: &MetalDevice,
267        kernel_width: usize,
268        kernel_height: usize,
269        weights: &[f32],
270    ) -> Option<Self> {
271        if weights.len() != kernel_width.saturating_mul(kernel_height) {
272            return None;
273        }
274
275        // SAFETY: `weights` lives for the duration of the FFI call.
276        let ptr = unsafe {
277            ffi::mps_image_convolution_new(
278                device.as_ptr(),
279                kernel_width,
280                kernel_height,
281                weights.as_ptr(),
282            )
283        };
284        if ptr.is_null() {
285            None
286        } else {
287            Some(Self { ptr })
288        }
289    }
290}
291impl_unary_methods!(ImageConvolution);
292
293opaque_handle!(ImageBilinearScale);
294impl ImageBilinearScale {
295    #[must_use]
296    pub fn new(device: &MetalDevice) -> Option<Self> {
297        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
298        let ptr = unsafe { ffi::mps_image_bilinear_scale_new(device.as_ptr()) };
299        if ptr.is_null() {
300            None
301        } else {
302            Some(Self { ptr })
303        }
304    }
305
306    /// Override the default fit-to-destination scale transform.
307    pub fn set_scale_transform(&self, transform: ScaleTransform) {
308        // SAFETY: The kernel pointer is valid for the duration of the call.
309        unsafe {
310            ffi::mps_image_scale_set_transform(
311                self.ptr,
312                transform.scale_x,
313                transform.scale_y,
314                transform.translate_x,
315                transform.translate_y,
316            );
317        };
318    }
319}
320impl_unary_methods!(ImageBilinearScale);
321
322opaque_handle!(ImageLanczosScale);
323impl ImageLanczosScale {
324    #[must_use]
325    pub fn new(device: &MetalDevice) -> Option<Self> {
326        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
327        let ptr = unsafe { ffi::mps_image_lanczos_scale_new(device.as_ptr()) };
328        if ptr.is_null() {
329            None
330        } else {
331            Some(Self { ptr })
332        }
333    }
334
335    /// Override the default fit-to-destination scale transform.
336    pub fn set_scale_transform(&self, transform: ScaleTransform) {
337        // SAFETY: The kernel pointer is valid for the duration of the call.
338        unsafe {
339            ffi::mps_image_scale_set_transform(
340                self.ptr,
341                transform.scale_x,
342                transform.scale_y,
343                transform.translate_x,
344                transform.translate_y,
345            );
346        };
347    }
348}
349impl_unary_methods!(ImageLanczosScale);
350
351opaque_handle!(ImageThresholdBinary);
352impl ImageThresholdBinary {
353    #[must_use]
354    pub fn new(device: &MetalDevice, threshold_value: f32, maximum_value: f32) -> Option<Self> {
355        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
356        let ptr = unsafe {
357            ffi::mps_image_threshold_binary_new(
358                device.as_ptr(),
359                threshold_value,
360                maximum_value,
361                core::ptr::null(),
362            )
363        };
364        if ptr.is_null() {
365            None
366        } else {
367            Some(Self { ptr })
368        }
369    }
370
371    #[must_use]
372    pub fn with_transform(
373        device: &MetalDevice,
374        threshold_value: f32,
375        maximum_value: f32,
376        transform: [f32; 3],
377    ) -> Option<Self> {
378        // SAFETY: `transform` lives for the duration of the FFI call.
379        let ptr = unsafe {
380            ffi::mps_image_threshold_binary_new(
381                device.as_ptr(),
382                threshold_value,
383                maximum_value,
384                transform.as_ptr(),
385            )
386        };
387        if ptr.is_null() {
388            None
389        } else {
390            Some(Self { ptr })
391        }
392    }
393}
394impl_unary_methods!(ImageThresholdBinary);
395
396opaque_handle!(ImageHistogram);
397impl ImageHistogram {
398    #[must_use]
399    pub fn new(device: &MetalDevice, info: HistogramInfo) -> Option<Self> {
400        // SAFETY: `info` arrays live for the duration of the FFI call.
401        let ptr = unsafe {
402            ffi::mps_image_histogram_new(
403                device.as_ptr(),
404                info.number_of_entries,
405                info.histogram_for_alpha,
406                info.min_pixel_value.as_ptr(),
407                info.max_pixel_value.as_ptr(),
408            )
409        };
410        if ptr.is_null() {
411            None
412        } else {
413            Some(Self { ptr })
414        }
415    }
416
417    /// Encode a histogram pass using an `MPSImage` source.
418    pub fn encode_image(
419        &self,
420        command_buffer: &CommandBuffer,
421        source: &Image,
422        histogram_buffer: &MetalBuffer,
423        histogram_offset: usize,
424    ) {
425        // SAFETY: All handles come from safe wrappers and remain alive for the call.
426        unsafe {
427            ffi::mps_image_histogram_encode_image(
428                self.ptr,
429                command_buffer.as_ptr(),
430                source.as_ptr(),
431                histogram_buffer.as_ptr(),
432                histogram_offset,
433            );
434        };
435    }
436
437    /// Encode a histogram pass using a raw `MTLTexture` source.
438    pub fn encode_texture(
439        &self,
440        command_buffer: &CommandBuffer,
441        source: &MetalTexture,
442        histogram_buffer: &MetalBuffer,
443        histogram_offset: usize,
444    ) {
445        // SAFETY: All handles come from safe wrappers and remain alive for the call.
446        unsafe {
447            ffi::mps_image_histogram_encode_texture(
448                self.ptr,
449                command_buffer.as_ptr(),
450                source.as_ptr(),
451                histogram_buffer.as_ptr(),
452                histogram_offset,
453            );
454        };
455    }
456
457    /// Report the minimum output buffer size for the given source pixel format.
458    #[must_use]
459    pub fn histogram_size_for_source_format(&self, source_format: usize) -> usize {
460        // SAFETY: The histogram pointer is valid for the duration of the call.
461        unsafe { ffi::mps_image_histogram_size_for_source_format(self.ptr, source_format) }
462    }
463}
464
465opaque_handle!(ImageStatisticsMinAndMax);
466impl ImageStatisticsMinAndMax {
467    #[must_use]
468    pub fn new(device: &MetalDevice) -> Option<Self> {
469        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
470        let ptr = unsafe { ffi::mps_image_statistics_min_max_new(device.as_ptr()) };
471        if ptr.is_null() {
472            None
473        } else {
474            Some(Self { ptr })
475        }
476    }
477}
478impl_unary_methods!(ImageStatisticsMinAndMax);
479
480opaque_handle!(ImageStatisticsMean);
481impl ImageStatisticsMean {
482    #[must_use]
483    pub fn new(device: &MetalDevice) -> Option<Self> {
484        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
485        let ptr = unsafe { ffi::mps_image_statistics_mean_new(device.as_ptr()) };
486        if ptr.is_null() {
487            None
488        } else {
489            Some(Self { ptr })
490        }
491    }
492}
493impl_unary_methods!(ImageStatisticsMean);
494
495opaque_handle!(ImageReduceRowMin);
496impl ImageReduceRowMin {
497    #[must_use]
498    pub fn new(device: &MetalDevice) -> Option<Self> {
499        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
500        let ptr = unsafe { ffi::mps_image_reduce_row_min_new(device.as_ptr()) };
501        if ptr.is_null() {
502            None
503        } else {
504            Some(Self { ptr })
505        }
506    }
507}
508impl_unary_methods!(ImageReduceRowMin);
509
510opaque_handle!(ImageReduceRowMax);
511impl ImageReduceRowMax {
512    #[must_use]
513    pub fn new(device: &MetalDevice) -> Option<Self> {
514        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
515        let ptr = unsafe { ffi::mps_image_reduce_row_max_new(device.as_ptr()) };
516        if ptr.is_null() {
517            None
518        } else {
519            Some(Self { ptr })
520        }
521    }
522}
523impl_unary_methods!(ImageReduceRowMax);
524
525opaque_handle!(ImageReduceRowMean);
526impl ImageReduceRowMean {
527    #[must_use]
528    pub fn new(device: &MetalDevice) -> Option<Self> {
529        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
530        let ptr = unsafe { ffi::mps_image_reduce_row_mean_new(device.as_ptr()) };
531        if ptr.is_null() {
532            None
533        } else {
534            Some(Self { ptr })
535        }
536    }
537}
538impl_unary_methods!(ImageReduceRowMean);
539
540opaque_handle!(ImageReduceRowSum);
541impl ImageReduceRowSum {
542    #[must_use]
543    pub fn new(device: &MetalDevice) -> Option<Self> {
544        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
545        let ptr = unsafe { ffi::mps_image_reduce_row_sum_new(device.as_ptr()) };
546        if ptr.is_null() {
547            None
548        } else {
549            Some(Self { ptr })
550        }
551    }
552}
553impl_unary_methods!(ImageReduceRowSum);
554
555opaque_handle!(ImageAdd);
556impl ImageAdd {
557    #[must_use]
558    pub fn new(device: &MetalDevice) -> Option<Self> {
559        // SAFETY: `device` exposes a valid `MTLDevice` pointer.
560        let ptr = unsafe { ffi::mps_image_add_new(device.as_ptr()) };
561        if ptr.is_null() {
562            None
563        } else {
564            Some(Self { ptr })
565        }
566    }
567
568    /// Set `primaryScale`, `secondaryScale`, and `bias` in one call.
569    pub fn set_scales(&self, primary_scale: f32, secondary_scale: f32, bias: f32) {
570        // SAFETY: The kernel pointer is valid for the duration of the call.
571        unsafe {
572            ffi::mps_image_arithmetic_set_scales_bias(
573                self.ptr,
574                primary_scale,
575                secondary_scale,
576                bias,
577            );
578        };
579    }
580
581    /// Clamp arithmetic results to the closed interval `[minimum_value, maximum_value]`.
582    pub fn set_clamp(&self, minimum_value: f32, maximum_value: f32) {
583        // SAFETY: The kernel pointer is valid for the duration of the call.
584        unsafe { ffi::mps_image_arithmetic_set_clamp(self.ptr, minimum_value, maximum_value) };
585    }
586}
587impl_binary_methods!(ImageAdd);
588
589/// Convenience wrapper for `scale-and-add` semantics implemented with `MPSImageAdd`.
590pub struct ImageScaleAndAdd {
591    inner: ImageAdd,
592}
593
594impl ImageScaleAndAdd {
595    /// Build an image add kernel with non-unit primary/secondary scales.
596    #[must_use]
597    pub fn new(
598        device: &MetalDevice,
599        primary_scale: f32,
600        secondary_scale: f32,
601        bias: f32,
602    ) -> Option<Self> {
603        let inner = ImageAdd::new(device)?;
604        inner.set_scales(primary_scale, secondary_scale, bias);
605        Some(Self { inner })
606    }
607
608    #[must_use]
609    pub const fn as_ptr(&self) -> *mut c_void {
610        self.inner.as_ptr()
611    }
612
613    pub fn encode_image(
614        &self,
615        command_buffer: &CommandBuffer,
616        primary: &Image,
617        secondary: &Image,
618        destination: &Image,
619    ) {
620        self.inner
621            .encode_image(command_buffer, primary, secondary, destination);
622    }
623
624    pub fn encode_texture(
625        &self,
626        command_buffer: &CommandBuffer,
627        primary: &MetalTexture,
628        secondary: &MetalTexture,
629        destination: &MetalTexture,
630    ) {
631        self.inner
632            .encode_texture(command_buffer, primary, secondary, destination);
633    }
634
635    pub fn set_primary_edge_mode(&self, edge_mode: usize) {
636        self.inner.set_primary_edge_mode(edge_mode);
637    }
638
639    pub fn set_secondary_edge_mode(&self, edge_mode: usize) {
640        self.inner.set_secondary_edge_mode(edge_mode);
641    }
642
643    pub fn set_clip_rect(&self, region: ImageRegion) {
644        self.inner.set_clip_rect(region);
645    }
646
647    pub fn set_clamp(&self, minimum_value: f32, maximum_value: f32) {
648        self.inner.set_clamp(minimum_value, maximum_value);
649    }
650}