Skip to main content

apple_mps/
image.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw values of Metal Performance Shaders' `MPSImageFeatureChannelFormat` enum.
///
/// Use one of these as `ImageDescriptor::channel_format`; the value is forwarded to the
/// native side unchanged.
pub mod feature_channel_format {
    /// `MPSImageFeatureChannelFormatNone`.
    pub const NONE: usize = 0;
    /// `MPSImageFeatureChannelFormatUnorm8` (8-bit unsigned normalized).
    pub const UNORM8: usize = 1;
    /// `MPSImageFeatureChannelFormatUnorm16` (16-bit unsigned normalized).
    pub const UNORM16: usize = 2;
    /// `MPSImageFeatureChannelFormatFloat16` (half-precision float).
    pub const FLOAT16: usize = 3;
    /// `MPSImageFeatureChannelFormatFloat32` (single-precision float).
    pub const FLOAT32: usize = 4;
}
15
/// Raw values of the `MPSDataLayout` enum, selecting how host-side bytes are ordered
/// when reading from or writing to an image.
// The constant names mirror Apple's `HeightxWidthxFeatureChannels` spelling; the lowercase
// `x` separators would otherwise trigger `non_upper_case_globals`.
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// `MPSDataLayoutHeightxWidthxFeatureChannels`: channels interleaved per pixel.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// `MPSDataLayoutFeatureChannelsxHeightxWidth`: one plane per feature channel.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
22
/// Raw values of the `MPSImageEdgeMode` enum.
pub mod image_edge_mode {
    /// Apple's `MPSImageEdgeModeZero`.
    pub const ZERO: usize = 0;
    /// Apple's `MPSImageEdgeModeClamp`.
    pub const CLAMP: usize = 1;
}
28
/// Raw `MPSKernelOptions` bit flags.
///
/// Every non-`NONE` value is a distinct single bit, so flags can be combined with `|`.
pub mod kernel_options {
    /// `MPSKernelOptionsNone`: no flags set.
    pub const NONE: u32 = 0x00;
    /// `MPSKernelOptionsSkipAPIValidation`.
    pub const SKIP_API_VALIDATION: u32 = 0x01;
    /// `MPSKernelOptionsAllowReducedPrecision`.
    pub const ALLOW_REDUCED_PRECISION: u32 = 0x02;
    /// `MPSKernelOptionsDisableInternalTiling`.
    pub const DISABLE_INTERNAL_TILING: u32 = 0x04;
    /// `MPSKernelOptionsInsertDebugGroups`.
    pub const INSERT_DEBUG_GROUPS: u32 = 0x08;
    /// `MPSKernelOptionsVerbose`.
    pub const VERBOSE: u32 = 0x10;
}
38
39/// Plain-Rust configuration for building a `MPSImageDescriptor` on the Swift side.
40#[derive(Debug, Clone, Copy)]
41pub struct ImageDescriptor {
42    pub channel_format: usize,
43    pub width: usize,
44    pub height: usize,
45    pub feature_channels: usize,
46    pub number_of_images: usize,
47    pub usage: usize,
48    pub storage_mode: usize,
49}
50
51impl ImageDescriptor {
52    /// Create a single-image descriptor with sensible defaults for read/write image processing.
53    #[must_use]
54    pub const fn new(
55        width: usize,
56        height: usize,
57        feature_channels: usize,
58        channel_format: usize,
59    ) -> Self {
60        Self {
61            channel_format,
62            width,
63            height,
64            feature_channels,
65            number_of_images: 1,
66            usage: texture_usage::SHADER_READ | texture_usage::SHADER_WRITE,
67            storage_mode: storage_mode::MANAGED,
68        }
69    }
70}
71
/// Rectangular region used for image transfer or clip-rect configuration.
///
/// The six fields are forwarded unchanged to the native read/write calls. `x`, `y` and `z`
/// give the starting corner; `width`, `height` and `depth` give the extent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageRegion {
    pub x: usize,
    pub y: usize,
    pub z: usize,
    pub width: usize,
    pub height: usize,
    pub depth: usize,
}

impl ImageRegion {
    /// Construct an arbitrary region from an origin and an extent.
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        Self {
            x,
            y,
            z,
            width,
            height,
            depth,
        }
    }

    /// Region covering the full first image slice: origin `(0, 0, 0)`, depth 1.
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        Self::new(0, 0, 0, width, height, 1)
    }
}
110
/// `MPSImageReadWriteParams` values: the window of feature channels to transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel to transfer.
    pub feature_channel_offset: usize,
    /// Number of consecutive feature channels to transfer.
    pub feature_channel_count: usize,
}

impl ImageReadWriteParams {
    /// Create a parameter block describing the feature-channel window to transfer.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        Self {
            feature_channel_offset,
            feature_channel_count,
        }
    }

    /// Transfer all `feature_channels` channels, starting at channel zero.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        Self::new(0, feature_channels)
    }
}
134
/// Safe owner for an Objective-C `MPSImage`.
///
/// Wraps exactly one retained object pointer, which is released when the wrapper is dropped.
pub struct Image {
    /// Retained `MPSImage` pointer owned by this wrapper.
    ptr: *mut c_void,
}

// SAFETY: MPSImage pointers are thread-safe Objective-C objects, so `&Image` may be shared
// across threads.
// NOTE(review): confirm that concurrent `read_bytes`/`write_bytes` calls on one MPSImage from
// different threads are actually safe.
unsafe impl Sync for Image {}
// SAFETY: MPSImage pointers are thread-safe Objective-C objects, so ownership (including the
// final release) may move to another thread.
unsafe impl Send for Image {}
144
145impl Drop for Image {
146    fn drop(&mut self) {
147        if !self.ptr.is_null() {
148            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
149            unsafe { ffi::mps_object_release(self.ptr) };
150            self.ptr = ptr::null_mut();
151        }
152    }
153}
154
155impl Image {
156    /// Allocate a lazily backed `MPSImage` on `device`.
157    #[must_use]
158    pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
159        // SAFETY: All pointers originate from safe wrappers and the scalar arguments are POD.
160        let ptr = unsafe {
161            ffi::mps_image_new_with_descriptor(
162                device.as_ptr(),
163                descriptor.channel_format,
164                descriptor.width,
165                descriptor.height,
166                descriptor.feature_channels,
167                descriptor.number_of_images,
168                descriptor.usage,
169                descriptor.storage_mode,
170            )
171        };
172        if ptr.is_null() {
173            None
174        } else {
175            Some(Self { ptr })
176        }
177    }
178
179    /// Wrap an existing Metal texture in an `MPSImage`.
180    #[must_use]
181    pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
182        // SAFETY: `texture` is a valid `MTLTexture` pointer from `apple-metal`.
183        let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
184        if ptr.is_null() {
185            None
186        } else {
187            Some(Self { ptr })
188        }
189    }
190
191    /// Raw `MPSImage` pointer.
192    #[must_use]
193    pub const fn as_ptr(&self) -> *mut c_void {
194        self.ptr
195    }
196
197    #[must_use]
198    pub(crate) const unsafe fn from_raw(ptr: *mut c_void) -> Self {
199        // SAFETY: Caller must ensure `ptr` is a valid +1 retained MPSImage pointer.
200        // SAFETY: Caller must ensure `ptr` is a valid +1 retained MPSImage pointer.
201        Self { ptr }
202    }
203
204    /// Image width in pixels.
205    #[must_use]
206    pub fn width(&self) -> usize {
207        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
208        unsafe { ffi::mps_image_width(self.ptr) }
209    }
210
211    /// Image height in pixels.
212    #[must_use]
213    pub fn height(&self) -> usize {
214        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
215        unsafe { ffi::mps_image_height(self.ptr) }
216    }
217
218    /// Number of feature channels per pixel.
219    #[must_use]
220    pub fn feature_channels(&self) -> usize {
221        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
222        unsafe { ffi::mps_image_feature_channels(self.ptr) }
223    }
224
225    /// Number of images stored in the backing texture array.
226    #[must_use]
227    pub fn number_of_images(&self) -> usize {
228        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
229        unsafe { ffi::mps_image_number_of_images(self.ptr) }
230    }
231
232    /// Bytes between neighboring pixels in storage order.
233    #[must_use]
234    pub fn pixel_size(&self) -> usize {
235        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
236        unsafe { ffi::mps_image_pixel_size(self.ptr) }
237    }
238
239    /// Underlying `MTLPixelFormat` raw value.
240    #[must_use]
241    pub fn pixel_format(&self) -> usize {
242        // SAFETY: `self.ptr` is a valid `MPSImage` pointer while `self` is alive.
243        unsafe { ffi::mps_image_pixel_format(self.ptr) }
244    }
245
246    /// Convenience region covering the full first image.
247    #[must_use]
248    pub fn whole_region(&self) -> ImageRegion {
249        ImageRegion::whole(self.width(), self.height())
250    }
251
252    /// Read bytes out of the image into a caller-provided buffer.
253    pub fn read_bytes(
254        &self,
255        dst: &mut [u8],
256        data_layout: usize,
257        bytes_per_row: usize,
258        region: ImageRegion,
259        params: ImageReadWriteParams,
260        image_index: usize,
261    ) -> Result<()> {
262        let expected = required_bytes(data_layout, bytes_per_row, region, params);
263        if dst.len() < expected {
264            return Err(Error::InvalidLength {
265                expected,
266                actual: dst.len(),
267            });
268        }
269
270        // SAFETY: `dst` is valid for writes of at least `expected` bytes and all handles are valid.
271        let _ = unsafe {
272            ffi::mps_image_read_bytes(
273                self.ptr,
274                dst.as_mut_ptr().cast(),
275                data_layout,
276                bytes_per_row,
277                region.x,
278                region.y,
279                region.z,
280                region.width,
281                region.height,
282                region.depth,
283                params.feature_channel_offset,
284                params.feature_channel_count,
285                image_index,
286            )
287        };
288        Ok(())
289    }
290
291    /// Write bytes into the image from a caller-provided buffer.
292    pub fn write_bytes(
293        &self,
294        src: &[u8],
295        data_layout: usize,
296        bytes_per_row: usize,
297        region: ImageRegion,
298        params: ImageReadWriteParams,
299        image_index: usize,
300    ) -> Result<()> {
301        let expected = required_bytes(data_layout, bytes_per_row, region, params);
302        if src.len() < expected {
303            return Err(Error::InvalidLength {
304                expected,
305                actual: src.len(),
306            });
307        }
308
309        // SAFETY: `src` is valid for reads of at least `expected` bytes and all handles are valid.
310        let _ = unsafe {
311            ffi::mps_image_write_bytes(
312                self.ptr,
313                src.as_ptr().cast(),
314                data_layout,
315                bytes_per_row,
316                region.x,
317                region.y,
318                region.z,
319                region.width,
320                region.height,
321                region.depth,
322                params.feature_channel_offset,
323                params.feature_channel_count,
324                image_index,
325            )
326        };
327        Ok(())
328    }
329
330    /// Read the first image slice as tightly packed float32 HWC data.
331    pub fn read_f32(&self) -> Result<Vec<f32>> {
332        let len = self.width() * self.height() * self.feature_channels();
333        let mut data = vec![0.0_f32; len];
334        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
335        // SAFETY: `data` is a contiguous `Vec<f32>` with exactly `len * size_of::<f32>()` bytes.
336        let bytes = unsafe {
337            core::slice::from_raw_parts_mut(
338                data.as_mut_ptr().cast::<u8>(),
339                core::mem::size_of_val(data.as_slice()),
340            )
341        };
342        self.read_bytes(
343            bytes,
344            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
345            bytes_per_row,
346            self.whole_region(),
347            ImageReadWriteParams::all(self.feature_channels()),
348            0,
349        )?;
350        Ok(data)
351    }
352
353    /// Write tightly packed float32 HWC data into the first image slice.
354    pub fn write_f32(&self, data: &[f32]) -> Result<()> {
355        let expected = self.width() * self.height() * self.feature_channels();
356        if data.len() != expected {
357            return Err(Error::InvalidLength {
358                expected: expected * core::mem::size_of::<f32>(),
359                actual: core::mem::size_of_val(data),
360            });
361        }
362
363        let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
364        // SAFETY: `data` is a contiguous slice of `f32`, which may be viewed as bytes.
365        let bytes = unsafe {
366            core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
367        };
368        self.write_bytes(
369            bytes,
370            image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
371            bytes_per_row,
372            self.whole_region(),
373            ImageReadWriteParams::all(self.feature_channels()),
374            0,
375        )
376    }
377}
378
379pub use crate::generated::image::*;
380
381fn required_bytes(
382    data_layout: usize,
383    bytes_per_row: usize,
384    region: ImageRegion,
385    params: ImageReadWriteParams,
386) -> usize {
387    let rows = region.height.saturating_mul(region.depth);
388    let base = bytes_per_row.saturating_mul(rows);
389    if data_layout == image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH {
390        base.saturating_mul(params.feature_channel_count.max(1))
391    } else {
392        base
393    }
394}