Skip to main content

apple_mps/
image.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw values of the Objective-C `MPSImageFeatureChannelFormat` enumeration.
///
/// The constants are plain `usize` values so they can cross the FFI boundary unchanged.
pub mod feature_channel_format {
    /// Raw value `0`: no format.
    pub const NONE: usize = 0;
    /// Raw value `1`: the `unorm8` format.
    pub const UNORM8: usize = 1;
    /// Raw value `2`: the `unorm16` format.
    pub const UNORM16: usize = 2;
    /// Raw value `3`: the `float16` format.
    pub const FLOAT16: usize = 3;
    /// Raw value `4`: the `float32` format.
    pub const FLOAT32: usize = 4;
}
15
/// Raw values of the Objective-C `MPSDataLayout` enumeration.
///
/// The names keep the lowercase `x` separators of the layout spelling, which is why the
/// upper-case-globals lint is silenced for this module.
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// Raw value `0`: height x width x feature channels.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// Raw value `1`: feature channels x height x width.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
22
/// Raw values of the Objective-C `MPSImageEdgeMode` enumeration.
pub mod image_edge_mode {
    /// Raw value `0`: the `zero` edge mode.
    pub const ZERO: usize = 0;
    /// Raw value `1`: the `clamp` edge mode.
    pub const CLAMP: usize = 1;
}
28
/// Bit flags mirroring the Objective-C `MPSKernelOptions` option set.
///
/// Each non-zero constant occupies its own bit, so values may be combined with `|`.
pub mod kernel_options {
    /// Empty option set.
    pub const NONE: u32 = 0x00;
    /// Bit 0.
    pub const SKIP_API_VALIDATION: u32 = 0x01;
    /// Bit 1.
    pub const ALLOW_REDUCED_PRECISION: u32 = 0x02;
    /// Bit 2.
    pub const DISABLE_INTERNAL_TILING: u32 = 0x04;
    /// Bit 3.
    pub const INSERT_DEBUG_GROUPS: u32 = 0x08;
    /// Bit 4.
    pub const VERBOSE: u32 = 0x10;
}
38
39/// Plain-Rust configuration for building a `MPSImageDescriptor` on the Swift side.
40#[derive(Debug, Clone, Copy)]
41pub struct ImageDescriptor {
42    pub channel_format: usize,
43    pub width: usize,
44    pub height: usize,
45    pub feature_channels: usize,
46    pub number_of_images: usize,
47    pub usage: usize,
48    pub storage_mode: usize,
49}
50
51impl ImageDescriptor {
52    /// Create a single-image descriptor with sensible defaults for read/write image processing.
53    #[must_use]
54    pub const fn new(
55        width: usize,
56        height: usize,
57        feature_channels: usize,
58        channel_format: usize,
59    ) -> Self {
60        Self {
61            channel_format,
62            width,
63            height,
64            feature_channels,
65            number_of_images: 1,
66            usage: texture_usage::SHADER_READ | texture_usage::SHADER_WRITE,
67            storage_mode: storage_mode::MANAGED,
68        }
69    }
70}
71
/// Box-shaped region (origin plus extent) used for image transfers and clip rectangles.
#[derive(Debug, Clone, Copy)]
pub struct ImageRegion {
    /// Origin along x.
    pub x: usize,
    /// Origin along y.
    pub y: usize,
    /// Origin along z.
    pub z: usize,
    /// Extent along x.
    pub width: usize,
    /// Extent along y.
    pub height: usize,
    /// Extent along z.
    pub depth: usize,
}

impl ImageRegion {
    /// Build a region from an explicit origin (`x`, `y`, `z`) and size (`width`, `height`, `depth`).
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        ImageRegion {
            x,
            y,
            z,
            width,
            height,
            depth,
        }
    }

    /// Region anchored at the origin that spans `width` x `height` with a depth of one.
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        ImageRegion {
            x: 0,
            y: 0,
            z: 0,
            width,
            height,
            depth: 1,
        }
    }
}
110
/// Plain-Rust mirror of `MPSImageReadWriteParams`: the feature-channel window of a transfer.
#[derive(Debug, Clone, Copy)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel in the window.
    pub feature_channel_offset: usize,
    /// Number of feature channels in the window.
    pub feature_channel_count: usize,
}

impl ImageReadWriteParams {
    /// Build a window of `feature_channel_count` channels starting at `feature_channel_offset`.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        ImageReadWriteParams {
            feature_channel_offset,
            feature_channel_count,
        }
    }

    /// Window that starts at channel zero and spans `feature_channels` channels.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        ImageReadWriteParams {
            feature_channel_offset: 0,
            feature_channel_count: feature_channels,
        }
    }
}
134
135/// Safe owner for an Objective-C `MPSImage`.
136pub struct Image {
137    ptr: *mut c_void,
138}
139
140unsafe impl Send for Image {}
141unsafe impl Sync for Image {}
142
143impl Drop for Image {
144    fn drop(&mut self) {
145        if !self.ptr.is_null() {
146            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
147            unsafe { ffi::mps_object_release(self.ptr) };
148            self.ptr = ptr::null_mut();
149        }
150    }
151}
152
153impl Image {
154    /// Allocate a lazily backed `MPSImage` on `device`.
155    #[must_use]
156    pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
157        // SAFETY: All pointers originate from safe wrappers and the scalar arguments are POD.
158        let ptr = unsafe {
159            ffi::mps_image_new_with_descriptor(
160                device.as_ptr(),
161                descriptor.channel_format,
162                descriptor.width,
163                descriptor.height,
164                descriptor.feature_channels,
165                descriptor.number_of_images,
166                descriptor.usage,
167                descriptor.storage_mode,
168            )
169        };
170        if ptr.is_null() {
171            None
172        } else {
173            Some(Self { ptr })
174        }
175    }
176
177    /// Wrap an existing Metal texture in an `MPSImage`.
178    #[must_use]
179    pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
180        // SAFETY: `texture` is a valid `MTLTexture` pointer from `apple-metal`.
181        let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
182        if ptr.is_null() {
183            None
184        } else {
185            Some(Self { ptr })
186        }
187    }
188
189    /// Raw `MPSImage` pointer.
190    #[must_use]
191    pub const fn as_ptr(&self) -> *mut c_void {
192        self.ptr
193    }
194
195    #[must_use]
196    pub(crate) const unsafe fn from_raw(ptr: *mut c_void) -> Self {
197        Self { ptr }
198    }
199
200    /// Image width in pixels.
201    #[must_use]
202    pub fn width(&self) -> usize {
203        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
204        unsafe { ffi::mps_image_width(self.ptr) }
205    }
206
207    /// Image height in pixels.
208    #[must_use]
209    pub fn height(&self) -> usize {
210        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
211        unsafe { ffi::mps_image_height(self.ptr) }
212    }
213
214    /// Number of feature channels per pixel.
215    #[must_use]
216    pub fn feature_channels(&self) -> usize {
217        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
218        unsafe { ffi::mps_image_feature_channels(self.ptr) }
219    }
220
221    /// Number of images stored in the backing texture array.
222    #[must_use]
223    pub fn number_of_images(&self) -> usize {
224        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
225        unsafe { ffi::mps_image_number_of_images(self.ptr) }
226    }
227
228    /// Bytes between neighboring pixels in storage order.
229    #[must_use]
230    pub fn pixel_size(&self) -> usize {
231        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
232        unsafe { ffi::mps_image_pixel_size(self.ptr) }
233    }
234
235    /// Underlying `MTLPixelFormat` raw value.
236    #[must_use]
237    pub fn pixel_format(&self) -> usize {
238        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
239        unsafe { ffi::mps_image_pixel_format(self.ptr) }
240    }
241
242    /// Convenience region covering the full first image.
243    #[must_use]
244    pub fn whole_region(&self) -> ImageRegion {
245        ImageRegion::whole(self.width(), self.height())
246    }
247
248    /// Read bytes out of the image into a caller-provided buffer.
249    pub fn read_bytes(
250        &self,
251        dst: &mut [u8],
252        data_layout: usize,
253        bytes_per_row: usize,
254        region: ImageRegion,
255        params: ImageReadWriteParams,
256        image_index: usize,
257    ) -> Result<()> {
258        let expected = required_bytes(data_layout, bytes_per_row, region, params);
259        if dst.len() < expected {
260            return Err(Error::InvalidLength {
261                expected,
262                actual: dst.len(),
263            });
264        }
265
266        // SAFETY: `dst` is valid for writes of at least `expected` bytes and all handles are valid.
267        let _ = unsafe {
268            ffi::mps_image_read_bytes(
269                self.ptr,
270                dst.as_mut_ptr().cast(),
271                data_layout,
272                bytes_per_row,
273                region.x,
274                region.y,
275                region.z,
276                region.width,
277                region.height,
278                region.depth,
279                params.feature_channel_offset,
280                params.feature_channel_count,
281                image_index,
282            )
283        };
284        Ok(())
285    }
286
287    /// Write bytes into the image from a caller-provided buffer.
288    pub fn write_bytes(
289        &self,
290        src: &[u8],
291        data_layout: usize,
292        bytes_per_row: usize,
293        region: ImageRegion,
294        params: ImageReadWriteParams,
295        image_index: usize,
296    ) -> Result<()> {
297        let expected = required_bytes(data_layout, bytes_per_row, region, params);
298        if src.len() < expected {
299            return Err(Error::InvalidLength {
300                expected,
301                actual: src.len(),
302            });
303        }
304
305        // SAFETY: `src` is valid for reads of at least `expected` bytes and all handles are valid.
306        let _ = unsafe {
307            ffi::mps_image_write_bytes(
308                self.ptr,
309                src.as_ptr().cast(),
310                data_layout,
311                bytes_per_row,
312                region.x,
313                region.y,
314                region.z,
315                region.width,
316                region.height,
317                region.depth,
318                params.feature_channel_offset,
319                params.feature_channel_count,
320                image_index,
321            )
322        };
323        Ok(())
324    }
325
326    /// Read the first image slice as tightly packed float32 HWC data.
327    pub fn read_f32(&self) -> Result<Vec<f32>> {
328        let len = self.width() * self.height() * self.feature_channels();
329        let mut data = vec![0.0_f32; len];
330        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
331        // SAFETY: `data` is a contiguous `Vec<f32>` with exactly `len * size_of::<f32>()` bytes.
332        let bytes = unsafe {
333            core::slice::from_raw_parts_mut(
334                data.as_mut_ptr().cast::<u8>(),
335                core::mem::size_of_val(data.as_slice()),
336            )
337        };
338        self.read_bytes(
339            bytes,
340            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
341            bytes_per_row,
342            self.whole_region(),
343            ImageReadWriteParams::all(self.feature_channels()),
344            0,
345        )?;
346        Ok(data)
347    }
348
349    /// Write tightly packed float32 HWC data into the first image slice.
350    pub fn write_f32(&self, data: &[f32]) -> Result<()> {
351        let expected = self.width() * self.height() * self.feature_channels();
352        if data.len() != expected {
353            return Err(Error::InvalidLength {
354                expected: expected * core::mem::size_of::<f32>(),
355                actual: core::mem::size_of_val(data),
356            });
357        }
358
359        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
360        // SAFETY: `data` is a contiguous slice of `f32`, which may be viewed as bytes.
361        let bytes = unsafe {
362            core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
363        };
364        self.write_bytes(
365            bytes,
366            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
367            bytes_per_row,
368            self.whole_region(),
369            ImageReadWriteParams::all(self.feature_channels()),
370            0,
371        )
372    }
373}
374
375fn required_bytes(
376    data_layout: usize,
377    bytes_per_row: usize,
378    region: ImageRegion,
379    params: ImageReadWriteParams,
380) -> usize {
381    let rows = region.height.saturating_mul(region.depth);
382    let base = bytes_per_row.saturating_mul(rows);
383    if data_layout == image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH {
384        base.saturating_mul(params.feature_channel_count.max(1))
385    } else {
386        base
387    }
388}