Skip to main content

apple_mps/
image.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw values of the `MPSImageFeatureChannelFormat` enumeration.
///
/// Pass one of these as `ImageDescriptor::channel_format`.
pub mod feature_channel_format {
    /// `MPSImageFeatureChannelFormatNone`.
    pub const NONE: usize = 0;
    /// `MPSImageFeatureChannelFormatUnorm8`.
    pub const UNORM8: usize = 1;
    /// `MPSImageFeatureChannelFormatUnorm16`.
    pub const UNORM16: usize = 2;
    /// `MPSImageFeatureChannelFormatFloat16`.
    pub const FLOAT16: usize = 3;
    /// `MPSImageFeatureChannelFormatFloat32`.
    pub const FLOAT32: usize = 4;
}
15
/// Raw values of the `MPSDataLayout` enumeration.
///
/// These select how pixel data is arranged in the CPU-side buffer given to
/// `Image::read_bytes` / `Image::write_bytes`.
// The lowercase `x` separators mirror Apple's enumerator names
// (`MPSDataLayoutHeightxWidthxFeatureChannels`), so the uppercase-globals lint is silenced here.
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// `MPSDataLayoutHeightxWidthxFeatureChannels`: feature channels stored together per pixel.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// `MPSDataLayoutFeatureChannelsxHeightxWidth`: one height x width plane per feature channel.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
22
/// Raw values of the `MPSImageEdgeMode` enumeration.
pub mod image_edge_mode {
    /// `MPSImageEdgeModeZero`.
    pub const ZERO: usize = 0;
    /// `MPSImageEdgeModeClamp`.
    pub const CLAMP: usize = 1;
}
28
/// Bit flags of the `MPSKernelOptions` option set; combine them with `|`.
///
/// NOTE(review): these are `u32`, unlike the `usize` constants elsewhere in this file — confirm
/// the width matches what the FFI layer expects.
pub mod kernel_options {
    /// `MPSKernelOptionsNone`: no flags set.
    pub const NONE: u32 = 0x00;
    /// `MPSKernelOptionsSkipAPIValidation`.
    pub const SKIP_API_VALIDATION: u32 = 0x01;
    /// `MPSKernelOptionsAllowReducedPrecision`.
    pub const ALLOW_REDUCED_PRECISION: u32 = 0x02;
    /// `MPSKernelOptionsDisableInternalTiling`.
    pub const DISABLE_INTERNAL_TILING: u32 = 0x04;
    /// `MPSKernelOptionsInsertDebugGroups`.
    pub const INSERT_DEBUG_GROUPS: u32 = 0x08;
    /// `MPSKernelOptionsVerbose`.
    pub const VERBOSE: u32 = 0x10;
}
38
39/// Plain-Rust configuration for building a `MPSImageDescriptor` on the Swift side.
40#[derive(Debug, Clone, Copy)]
41pub struct ImageDescriptor {
42    pub channel_format: usize,
43    pub width: usize,
44    pub height: usize,
45    pub feature_channels: usize,
46    pub number_of_images: usize,
47    pub usage: usize,
48    pub storage_mode: usize,
49}
50
51impl ImageDescriptor {
52    /// Create a single-image descriptor with sensible defaults for read/write image processing.
53    #[must_use]
54    pub const fn new(
55        width: usize,
56        height: usize,
57        feature_channels: usize,
58        channel_format: usize,
59    ) -> Self {
60        Self {
61            channel_format,
62            width,
63            height,
64            feature_channels,
65            number_of_images: 1,
66            usage: texture_usage::SHADER_READ | texture_usage::SHADER_WRITE,
67            storage_mode: storage_mode::MANAGED,
68        }
69    }
70}
71
/// Rectangular region used for image transfer or clip-rect configuration.
///
/// `x`/`y`/`z` give the origin and `width`/`height`/`depth` give the size along each axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageRegion {
    /// Origin along the x axis.
    pub x: usize,
    /// Origin along the y axis.
    pub y: usize,
    /// Origin along the z axis.
    pub z: usize,
    /// Size along the x axis.
    pub width: usize,
    /// Size along the y axis.
    pub height: usize,
    /// Size along the z axis.
    pub depth: usize,
}

impl ImageRegion {
    /// Construct an arbitrary region from its origin and size.
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        Self {
            x,
            y,
            z,
            width,
            height,
            depth,
        }
    }

    /// Region covering the full first image slice: origin `(0, 0, 0)`, depth 1.
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        Self::new(0, 0, 0, width, height, 1)
    }
}
110
/// `MPSImageReadWriteParams` values: the window of feature channels a transfer touches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel to transfer.
    pub feature_channel_offset: usize,
    /// Number of feature channels to transfer.
    pub feature_channel_count: usize,
}

impl ImageReadWriteParams {
    /// Create a parameter block describing the feature-channel window to transfer.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        Self {
            feature_channel_offset,
            feature_channel_count,
        }
    }

    /// Transfer `feature_channels` channels starting at channel zero.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        Self::new(0, feature_channels)
    }
}
134
135/// Safe owner for an Objective-C `MPSImage`.
136pub struct Image {
137    ptr: *mut c_void,
138}
139
140unsafe impl Send for Image {}
141unsafe impl Sync for Image {}
142
143impl Drop for Image {
144    fn drop(&mut self) {
145        if !self.ptr.is_null() {
146            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
147            unsafe { ffi::mps_object_release(self.ptr) };
148            self.ptr = ptr::null_mut();
149        }
150    }
151}
152
153impl Image {
154    /// Allocate a lazily backed `MPSImage` on `device`.
155    #[must_use]
156    pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
157        // SAFETY: All pointers originate from safe wrappers and the scalar arguments are POD.
158        let ptr = unsafe {
159            ffi::mps_image_new_with_descriptor(
160                device.as_ptr(),
161                descriptor.channel_format,
162                descriptor.width,
163                descriptor.height,
164                descriptor.feature_channels,
165                descriptor.number_of_images,
166                descriptor.usage,
167                descriptor.storage_mode,
168            )
169        };
170        if ptr.is_null() {
171            None
172        } else {
173            Some(Self { ptr })
174        }
175    }
176
177    /// Wrap an existing Metal texture in an `MPSImage`.
178    #[must_use]
179    pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
180        // SAFETY: `texture` is a valid `MTLTexture` pointer from `apple-metal`.
181        let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
182        if ptr.is_null() {
183            None
184        } else {
185            Some(Self { ptr })
186        }
187    }
188
189    /// Raw `MPSImage` pointer.
190    #[must_use]
191    pub const fn as_ptr(&self) -> *mut c_void {
192        self.ptr
193    }
194
195    /// Image width in pixels.
196    #[must_use]
197    pub fn width(&self) -> usize {
198        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
199        unsafe { ffi::mps_image_width(self.ptr) }
200    }
201
202    /// Image height in pixels.
203    #[must_use]
204    pub fn height(&self) -> usize {
205        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
206        unsafe { ffi::mps_image_height(self.ptr) }
207    }
208
209    /// Number of feature channels per pixel.
210    #[must_use]
211    pub fn feature_channels(&self) -> usize {
212        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
213        unsafe { ffi::mps_image_feature_channels(self.ptr) }
214    }
215
216    /// Number of images stored in the backing texture array.
217    #[must_use]
218    pub fn number_of_images(&self) -> usize {
219        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
220        unsafe { ffi::mps_image_number_of_images(self.ptr) }
221    }
222
223    /// Bytes between neighboring pixels in storage order.
224    #[must_use]
225    pub fn pixel_size(&self) -> usize {
226        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
227        unsafe { ffi::mps_image_pixel_size(self.ptr) }
228    }
229
230    /// Underlying `MTLPixelFormat` raw value.
231    #[must_use]
232    pub fn pixel_format(&self) -> usize {
233        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
234        unsafe { ffi::mps_image_pixel_format(self.ptr) }
235    }
236
237    /// Convenience region covering the full first image.
238    #[must_use]
239    pub fn whole_region(&self) -> ImageRegion {
240        ImageRegion::whole(self.width(), self.height())
241    }
242
243    /// Read bytes out of the image into a caller-provided buffer.
244    pub fn read_bytes(
245        &self,
246        dst: &mut [u8],
247        data_layout: usize,
248        bytes_per_row: usize,
249        region: ImageRegion,
250        params: ImageReadWriteParams,
251        image_index: usize,
252    ) -> Result<()> {
253        let expected = required_bytes(data_layout, bytes_per_row, region, params);
254        if dst.len() < expected {
255            return Err(Error::InvalidLength {
256                expected,
257                actual: dst.len(),
258            });
259        }
260
261        // SAFETY: `dst` is valid for writes of at least `expected` bytes and all handles are valid.
262        let _ = unsafe {
263            ffi::mps_image_read_bytes(
264                self.ptr,
265                dst.as_mut_ptr().cast(),
266                data_layout,
267                bytes_per_row,
268                region.x,
269                region.y,
270                region.z,
271                region.width,
272                region.height,
273                region.depth,
274                params.feature_channel_offset,
275                params.feature_channel_count,
276                image_index,
277            )
278        };
279        Ok(())
280    }
281
282    /// Write bytes into the image from a caller-provided buffer.
283    pub fn write_bytes(
284        &self,
285        src: &[u8],
286        data_layout: usize,
287        bytes_per_row: usize,
288        region: ImageRegion,
289        params: ImageReadWriteParams,
290        image_index: usize,
291    ) -> Result<()> {
292        let expected = required_bytes(data_layout, bytes_per_row, region, params);
293        if src.len() < expected {
294            return Err(Error::InvalidLength {
295                expected,
296                actual: src.len(),
297            });
298        }
299
300        // SAFETY: `src` is valid for reads of at least `expected` bytes and all handles are valid.
301        let _ = unsafe {
302            ffi::mps_image_write_bytes(
303                self.ptr,
304                src.as_ptr().cast(),
305                data_layout,
306                bytes_per_row,
307                region.x,
308                region.y,
309                region.z,
310                region.width,
311                region.height,
312                region.depth,
313                params.feature_channel_offset,
314                params.feature_channel_count,
315                image_index,
316            )
317        };
318        Ok(())
319    }
320
321    /// Read the first image slice as tightly packed float32 HWC data.
322    pub fn read_f32(&self) -> Result<Vec<f32>> {
323        let len = self.width() * self.height() * self.feature_channels();
324        let mut data = vec![0.0_f32; len];
325        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
326        // SAFETY: `data` is a contiguous `Vec<f32>` with exactly `len * size_of::<f32>()` bytes.
327        let bytes = unsafe {
328            core::slice::from_raw_parts_mut(
329                data.as_mut_ptr().cast::<u8>(),
330                core::mem::size_of_val(data.as_slice()),
331            )
332        };
333        self.read_bytes(
334            bytes,
335            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
336            bytes_per_row,
337            self.whole_region(),
338            ImageReadWriteParams::all(self.feature_channels()),
339            0,
340        )?;
341        Ok(data)
342    }
343
344    /// Write tightly packed float32 HWC data into the first image slice.
345    pub fn write_f32(&self, data: &[f32]) -> Result<()> {
346        let expected = self.width() * self.height() * self.feature_channels();
347        if data.len() != expected {
348            return Err(Error::InvalidLength {
349                expected: expected * core::mem::size_of::<f32>(),
350                actual: core::mem::size_of_val(data),
351            });
352        }
353
354        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
355        // SAFETY: `data` is a contiguous slice of `f32`, which may be viewed as bytes.
356        let bytes = unsafe {
357            core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
358        };
359        self.write_bytes(
360            bytes,
361            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
362            bytes_per_row,
363            self.whole_region(),
364            ImageReadWriteParams::all(self.feature_channels()),
365            0,
366        )
367    }
368}
369
370fn required_bytes(
371    data_layout: usize,
372    bytes_per_row: usize,
373    region: ImageRegion,
374    params: ImageReadWriteParams,
375) -> usize {
376    let rows = region.height.saturating_mul(region.depth);
377    let base = bytes_per_row.saturating_mul(rows);
378    if data_layout == image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH {
379        base.saturating_mul(params.feature_channel_count.max(1))
380    } else {
381        base
382    }
383}