Skip to main content

apple_mps/
image.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw `MPSImageFeatureChannelFormat` values, for use as the `channel_format` of an
/// `ImageDescriptor`.
pub mod feature_channel_format {
    /// Raw value of `MPSImageFeatureChannelFormatNone`.
    pub const NONE: usize = 0;
    /// Raw value of `MPSImageFeatureChannelFormatUnorm8`.
    pub const UNORM8: usize = 1;
    /// Raw value of `MPSImageFeatureChannelFormatUnorm16`.
    pub const UNORM16: usize = 2;
    /// Raw value of `MPSImageFeatureChannelFormatFloat16`.
    pub const FLOAT16: usize = 3;
    /// Raw value of `MPSImageFeatureChannelFormatFloat32`.
    pub const FLOAT32: usize = 4;
}
20
/// Raw `MPSDataLayout` values describing how pixel bytes are ordered in host memory.
// The lowercase `x` separators keep these names close to the Objective-C case names, so the
// upper-case naming lint is silenced for this module only.
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// Raw value of `MPSDataLayoutHeightxWidthxFeatureChannels`: rows of pixels, with each
    /// pixel's feature channels stored next to each other.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// Raw value of `MPSDataLayoutFeatureChannelsxHeightxWidth`: one height-by-width plane per
    /// feature channel.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
29
/// Raw `MPSImageEdgeMode` values selecting how kernels treat reads beyond the image edge.
pub mod image_edge_mode {
    /// Raw value of `MPSImageEdgeModeZero`.
    pub const ZERO: usize = 0;
    /// Raw value of `MPSImageEdgeModeClamp`.
    pub const CLAMP: usize = 1;
}
37
/// Raw `MPSKernelOptions` bit flags; each non-zero flag is a distinct bit, so combine them
/// with `|`.
pub mod kernel_options {
    // NOTE(review): these flags are `u32`, while the other constant modules in this file use
    // `usize`; confirm this matches the FFI parameter type that consumes them.

    /// No option bits set.
    pub const NONE: u32 = 0x00;
    /// `MPSKernelOptionsSkipAPIValidation` bit.
    pub const SKIP_API_VALIDATION: u32 = 0x01;
    /// `MPSKernelOptionsAllowReducedPrecision` bit.
    pub const ALLOW_REDUCED_PRECISION: u32 = 0x02;
    /// `MPSKernelOptionsDisableInternalTiling` bit.
    pub const DISABLE_INTERNAL_TILING: u32 = 0x04;
    /// `MPSKernelOptionsInsertDebugGroups` bit.
    pub const INSERT_DEBUG_GROUPS: u32 = 0x08;
    /// `MPSKernelOptionsVerbose` bit.
    pub const VERBOSE: u32 = 0x10;
}
53
/// Plain-Rust configuration that is forwarded field-by-field to the Swift side to build a
/// `MPSImageDescriptor`.
///
/// All fields are raw integer values; comparing two descriptors compares every field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageDescriptor {
    /// Raw `MPSImageFeatureChannelFormat` value (see the `feature_channel_format` module).
    pub channel_format: usize,
    /// Image width in pixels.
    pub width: usize,
    /// Image height in pixels.
    pub height: usize,
    /// Number of feature channels per pixel.
    pub feature_channels: usize,
    /// Number of images stored in the backing texture (array).
    pub number_of_images: usize,
    /// Raw `MTLTextureUsage` bit set for the backing texture.
    pub usage: usize,
    /// Raw `MTLStorageMode` value for the backing texture.
    pub storage_mode: usize,
}
72
73impl ImageDescriptor {
74    /// Create a single-image descriptor with sensible defaults for read/write image processing.
75    #[must_use]
76    pub const fn new(
77        width: usize,
78        height: usize,
79        feature_channels: usize,
80        channel_format: usize,
81    ) -> Self {
82        Self {
83            channel_format,
84            width,
85            height,
86            feature_channels,
87            number_of_images: 1,
88            usage: texture_usage::SHADER_READ | texture_usage::SHADER_WRITE,
89            storage_mode: storage_mode::MANAGED,
90        }
91    }
92}
93
/// Axis-aligned box (origin plus size) selecting the pixels involved in an image transfer or
/// a clip-rect configuration.
///
/// Two regions compare equal when all six coordinates match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageRegion {
    /// Origin along the x axis, in pixels.
    pub x: usize,
    /// Origin along the y axis, in pixels.
    pub y: usize,
    /// Origin along the z axis (slice index).
    pub z: usize,
    /// Extent along the x axis, in pixels.
    pub width: usize,
    /// Extent along the y axis, in pixels.
    pub height: usize,
    /// Extent along the z axis, in slices.
    pub depth: usize,
}

impl ImageRegion {
    /// Constructs a region from an explicit origin `(x, y, z)` and size
    /// `(width, height, depth)`.
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        Self {
            x,
            y,
            z,
            width,
            height,
            depth,
        }
    }

    /// Region anchored at the origin that covers `width` x `height` pixels of the first slice
    /// (`depth == 1`).
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        Self::new(0, 0, 0, width, height, 1)
    }
}
138
/// Feature-channel window for an image transfer; mirrors `MPSImageReadWriteParams`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel to transfer (`featureChannelOffset`).
    pub feature_channel_offset: usize,
    /// Number of feature channels to transfer (`numberOfFeatureChannelsToReadWrite`).
    pub feature_channel_count: usize,
}

impl ImageReadWriteParams {
    /// Describes a window of `feature_channel_count` channels starting at
    /// `feature_channel_offset`.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        Self {
            feature_channel_offset,
            feature_channel_count,
        }
    }

    /// Window covering channels `0..feature_channels`.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        Self::new(0, feature_channels)
    }
}
164
/// Owning wrapper around an Objective-C `MPSImage` object pointer.
///
/// The pointer carries one retain count that belongs to this value.
pub struct Image {
    /// Retained `MPSImage`; released when the wrapper is dropped.
    ptr: *mut c_void,
}
169
170// SAFETY: MPSImage pointers are thread-safe Objective-C objects.
171unsafe impl Send for Image {}
172// SAFETY: MPSImage pointers are thread-safe Objective-C objects.
173unsafe impl Sync for Image {}
174
175impl Drop for Image {
176    fn drop(&mut self) {
177        if !self.ptr.is_null() {
178            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
179            unsafe { ffi::mps_object_release(self.ptr) };
180            self.ptr = ptr::null_mut();
181        }
182    }
183}
184
185impl Image {
186    /// Allocate a lazily backed `MPSImage` on `device`.
187    #[must_use]
188    pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
189        // SAFETY: All pointers originate from safe wrappers and the scalar arguments are POD.
190        let ptr = unsafe {
191            ffi::mps_image_new_with_descriptor(
192                device.as_ptr(),
193                descriptor.channel_format,
194                descriptor.width,
195                descriptor.height,
196                descriptor.feature_channels,
197                descriptor.number_of_images,
198                descriptor.usage,
199                descriptor.storage_mode,
200            )
201        };
202        if ptr.is_null() {
203            None
204        } else {
205            Some(Self { ptr })
206        }
207    }
208
209    /// Wrap an existing Metal texture in an `MPSImage`.
210    #[must_use]
211    pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
212        // SAFETY: `texture` is a valid `MTLTexture` pointer from `apple-metal`.
213        let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
214        if ptr.is_null() {
215            None
216        } else {
217            Some(Self { ptr })
218        }
219    }
220
221    /// Raw `MPSImage` pointer.
222    #[must_use]
223    pub const fn as_ptr(&self) -> *mut c_void {
224        self.ptr
225    }
226
227    #[must_use]
228    pub(crate) const unsafe fn from_raw(ptr: *mut c_void) -> Self {
229        // SAFETY: Caller must ensure `ptr` is a valid +1 retained MPSImage pointer.
230        // SAFETY: Caller must ensure `ptr` is a valid +1 retained MPSImage pointer.
231        Self { ptr }
232    }
233
234    /// Image width in pixels.
235    #[must_use]
236    pub fn width(&self) -> usize {
237        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
238        unsafe { ffi::mps_image_width(self.ptr) }
239    }
240
241    /// Image height in pixels.
242    #[must_use]
243    pub fn height(&self) -> usize {
244        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
245        unsafe { ffi::mps_image_height(self.ptr) }
246    }
247
248    /// Number of feature channels per pixel.
249    #[must_use]
250    pub fn feature_channels(&self) -> usize {
251        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
252        unsafe { ffi::mps_image_feature_channels(self.ptr) }
253    }
254
255    /// Number of images stored in the backing texture array.
256    #[must_use]
257    pub fn number_of_images(&self) -> usize {
258        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
259        unsafe { ffi::mps_image_number_of_images(self.ptr) }
260    }
261
262    /// Bytes between neighboring pixels in storage order.
263    #[must_use]
264    pub fn pixel_size(&self) -> usize {
265        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
266        unsafe { ffi::mps_image_pixel_size(self.ptr) }
267    }
268
269    /// Underlying `MTLPixelFormat` raw value.
270    #[must_use]
271    pub fn pixel_format(&self) -> usize {
272        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
273        unsafe { ffi::mps_image_pixel_format(self.ptr) }
274    }
275
276    /// Convenience region covering the full first image.
277    #[must_use]
278    pub fn whole_region(&self) -> ImageRegion {
279        ImageRegion::whole(self.width(), self.height())
280    }
281
282    /// Read bytes out of the image into a caller-provided buffer.
283    pub fn read_bytes(
284        &self,
285        dst: &mut [u8],
286        data_layout: usize,
287        bytes_per_row: usize,
288        region: ImageRegion,
289        params: ImageReadWriteParams,
290        image_index: usize,
291    ) -> Result<()> {
292        let expected = required_bytes(data_layout, bytes_per_row, region, params);
293        if dst.len() < expected {
294            return Err(Error::InvalidLength {
295                expected,
296                actual: dst.len(),
297            });
298        }
299
300        // SAFETY: `dst` is valid for writes of at least `expected` bytes and all handles are valid.
301        let _ = unsafe {
302            ffi::mps_image_read_bytes(
303                self.ptr,
304                dst.as_mut_ptr().cast(),
305                data_layout,
306                bytes_per_row,
307                region.x,
308                region.y,
309                region.z,
310                region.width,
311                region.height,
312                region.depth,
313                params.feature_channel_offset,
314                params.feature_channel_count,
315                image_index,
316            )
317        };
318        Ok(())
319    }
320
321    /// Write bytes into the image from a caller-provided buffer.
322    pub fn write_bytes(
323        &self,
324        src: &[u8],
325        data_layout: usize,
326        bytes_per_row: usize,
327        region: ImageRegion,
328        params: ImageReadWriteParams,
329        image_index: usize,
330    ) -> Result<()> {
331        let expected = required_bytes(data_layout, bytes_per_row, region, params);
332        if src.len() < expected {
333            return Err(Error::InvalidLength {
334                expected,
335                actual: src.len(),
336            });
337        }
338
339        // SAFETY: `src` is valid for reads of at least `expected` bytes and all handles are valid.
340        let _ = unsafe {
341            ffi::mps_image_write_bytes(
342                self.ptr,
343                src.as_ptr().cast(),
344                data_layout,
345                bytes_per_row,
346                region.x,
347                region.y,
348                region.z,
349                region.width,
350                region.height,
351                region.depth,
352                params.feature_channel_offset,
353                params.feature_channel_count,
354                image_index,
355            )
356        };
357        Ok(())
358    }
359
360    /// Read the first image slice as tightly packed float32 HWC data.
361    pub fn read_f32(&self) -> Result<Vec<f32>> {
362        let len = self.width() * self.height() * self.feature_channels();
363        let mut data = vec![0.0_f32; len];
364        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
365        // SAFETY: `data` is a contiguous `Vec<f32>` with exactly `len * size_of::<f32>()` bytes.
366        let bytes = unsafe {
367            core::slice::from_raw_parts_mut(
368                data.as_mut_ptr().cast::<u8>(),
369                core::mem::size_of_val(data.as_slice()),
370            )
371        };
372        self.read_bytes(
373            bytes,
374            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
375            bytes_per_row,
376            self.whole_region(),
377            ImageReadWriteParams::all(self.feature_channels()),
378            0,
379        )?;
380        Ok(data)
381    }
382
383    /// Write tightly packed float32 HWC data into the first image slice.
384    pub fn write_f32(&self, data: &[f32]) -> Result<()> {
385        let expected = self.width() * self.height() * self.feature_channels();
386        if data.len() != expected {
387            return Err(Error::InvalidLength {
388                expected: expected * core::mem::size_of::<f32>(),
389                actual: core::mem::size_of_val(data),
390            });
391        }
392
393        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
394        // SAFETY: `data` is a contiguous slice of `f32`, which may be viewed as bytes.
395        let bytes = unsafe {
396            core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
397        };
398        self.write_bytes(
399            bytes,
400            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
401            bytes_per_row,
402            self.whole_region(),
403            ImageReadWriteParams::all(self.feature_channels()),
404            0,
405        )
406    }
407}
408
409#[doc(hidden)]
410pub use crate::generated::image::*;
411
412fn required_bytes(
413    data_layout: usize,
414    bytes_per_row: usize,
415    region: ImageRegion,
416    params: ImageReadWriteParams,
417) -> usize {
418    let rows = region.height.saturating_mul(region.depth);
419    let base = bytes_per_row.saturating_mul(rows);
420    if data_layout == image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH {
421        base.saturating_mul(params.feature_channel_count.max(1))
422    } else {
423        base
424    }
425}