1use crate::error::{Error, Result};
2use crate::ffi;
3use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
4use core::ffi::c_void;
5use core::ptr;
6
/// Storage formats for the feature channels of an image.
///
/// NOTE(review): the numeric values appear to mirror Metal Performance
/// Shaders' `MPSImageFeatureChannelFormat`; confirm against the SDK header.
pub mod feature_channel_format {
    /// No channel format specified.
    pub const NONE: usize = 0;
    /// 8-bit unsigned normalized channels.
    pub const UNORM8: usize = 1;
    /// 16-bit unsigned normalized channels.
    pub const UNORM16: usize = 2;
    /// 16-bit floating-point channels.
    pub const FLOAT16: usize = 3;
    /// 32-bit floating-point channels.
    pub const FLOAT32: usize = 4;
}
15
/// Orderings of pixel data in host buffers handed to `Image::read_bytes` and
/// `Image::write_bytes`.
// The lowercase `x` separators in the constant names trip the
// `non_upper_case_globals` lint; it is silenced so the names can read as
// dimension lists (outermost dimension first).
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// Interleaved rows: each row holds `width` pixels of all feature channels.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// Planar: one `height` x `width` plane per feature channel.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
22
/// Edge-handling modes.
///
/// NOTE(review): by name, `ZERO` presumably treats out-of-bounds reads as
/// zero and `CLAMP` clamps to the nearest edge pixel, mirroring
/// `MPSImageEdgeMode`; confirm against the SDK header.
pub mod image_edge_mode {
    /// Out-of-bounds reads yield zero (see module note).
    pub const ZERO: usize = 0;
    /// Out-of-bounds reads clamp to the edge (see module note).
    pub const CLAMP: usize = 1;
}
28
/// Kernel option bit flags; each flag occupies a distinct bit, so they can be
/// combined with `|`.
///
/// NOTE(review): the bit assignments appear to mirror `MPSKernelOptions`;
/// confirm against the SDK header.
pub mod kernel_options {
    /// No options set.
    pub const NONE: u32 = 0;
    /// Skip API validation.
    pub const SKIP_API_VALIDATION: u32 = 1 << 0;
    /// Permit reduced-precision arithmetic.
    pub const ALLOW_REDUCED_PRECISION: u32 = 1 << 1;
    /// Disable internal tiling.
    pub const DISABLE_INTERNAL_TILING: u32 = 1 << 2;
    /// Insert debug groups.
    pub const INSERT_DEBUG_GROUPS: u32 = 1 << 3;
    /// Enable verbose output.
    pub const VERBOSE: u32 = 1 << 4;
}
38
39#[derive(Debug, Clone, Copy)]
41pub struct ImageDescriptor {
42 pub channel_format: usize,
43 pub width: usize,
44 pub height: usize,
45 pub feature_channels: usize,
46 pub number_of_images: usize,
47 pub usage: usize,
48 pub storage_mode: usize,
49}
50
51impl ImageDescriptor {
52 #[must_use]
54 pub const fn new(
55 width: usize,
56 height: usize,
57 feature_channels: usize,
58 channel_format: usize,
59 ) -> Self {
60 Self {
61 channel_format,
62 width,
63 height,
64 feature_channels,
65 number_of_images: 1,
66 usage: texture_usage::SHADER_READ | texture_usage::SHADER_WRITE,
67 storage_mode: storage_mode::MANAGED,
68 }
69 }
70}
71
/// A 3-D box (origin plus size, in pixels) selecting part of an image for a
/// read or write.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageRegion {
    /// Origin along the x axis.
    pub x: usize,
    /// Origin along the y axis.
    pub y: usize,
    /// Origin along the z axis.
    pub z: usize,
    /// Extent along the x axis.
    pub width: usize,
    /// Extent along the y axis.
    pub height: usize,
    /// Extent along the z axis.
    pub depth: usize,
}

impl ImageRegion {
    /// Creates a region with origin `(x, y, z)` and size
    /// `width` x `height` x `depth`.
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        Self {
            x,
            y,
            z,
            width,
            height,
            depth,
        }
    }

    /// Creates a region covering a full `width` x `height` plane: origin at
    /// `(0, 0, 0)` and depth 1.
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        Self::new(0, 0, 0, width, height, 1)
    }
}
110
/// Selects the range of feature channels that a read or write transfers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel to transfer.
    pub feature_channel_offset: usize,
    /// Number of consecutive feature channels to transfer.
    pub feature_channel_count: usize,
}

impl ImageReadWriteParams {
    /// Creates parameters covering `feature_channel_count` channels starting
    /// at channel `feature_channel_offset`.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        Self {
            feature_channel_offset,
            feature_channel_count,
        }
    }

    /// Creates parameters covering channels `0..feature_channels`.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        Self::new(0, feature_channels)
    }
}
134
/// Owning handle to an image object created through the MPS FFI layer.
pub struct Image {
    // Raw handle from one of the `ffi::mps_image_new_*` constructors (non-null
    // for handles built by `new` / `from_texture`) or from `from_raw`.
    ptr: *mut c_void,
}

// SAFETY: NOTE(review): these impls assert that the underlying image object
// can be used and released from any thread. Because `write_bytes` takes
// `&self`, `Sync` also lets several threads write to one image concurrently —
// confirm the FFI layer tolerates that.
unsafe impl Send for Image {}
unsafe impl Sync for Image {}
142
143impl Drop for Image {
144 fn drop(&mut self) {
145 if !self.ptr.is_null() {
146 unsafe { ffi::mps_object_release(self.ptr) };
148 self.ptr = ptr::null_mut();
149 }
150 }
151}
152
153impl Image {
154 #[must_use]
156 pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
157 let ptr = unsafe {
159 ffi::mps_image_new_with_descriptor(
160 device.as_ptr(),
161 descriptor.channel_format,
162 descriptor.width,
163 descriptor.height,
164 descriptor.feature_channels,
165 descriptor.number_of_images,
166 descriptor.usage,
167 descriptor.storage_mode,
168 )
169 };
170 if ptr.is_null() {
171 None
172 } else {
173 Some(Self { ptr })
174 }
175 }
176
177 #[must_use]
179 pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
180 let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
182 if ptr.is_null() {
183 None
184 } else {
185 Some(Self { ptr })
186 }
187 }
188
189 #[must_use]
191 pub const fn as_ptr(&self) -> *mut c_void {
192 self.ptr
193 }
194
195 #[must_use]
196 pub(crate) const unsafe fn from_raw(ptr: *mut c_void) -> Self {
197 Self { ptr }
198 }
199
200 #[must_use]
202 pub fn width(&self) -> usize {
203 unsafe { ffi::mps_image_width(self.ptr) }
205 }
206
207 #[must_use]
209 pub fn height(&self) -> usize {
210 unsafe { ffi::mps_image_height(self.ptr) }
212 }
213
214 #[must_use]
216 pub fn feature_channels(&self) -> usize {
217 unsafe { ffi::mps_image_feature_channels(self.ptr) }
219 }
220
221 #[must_use]
223 pub fn number_of_images(&self) -> usize {
224 unsafe { ffi::mps_image_number_of_images(self.ptr) }
226 }
227
228 #[must_use]
230 pub fn pixel_size(&self) -> usize {
231 unsafe { ffi::mps_image_pixel_size(self.ptr) }
233 }
234
235 #[must_use]
237 pub fn pixel_format(&self) -> usize {
238 unsafe { ffi::mps_image_pixel_format(self.ptr) }
240 }
241
242 #[must_use]
244 pub fn whole_region(&self) -> ImageRegion {
245 ImageRegion::whole(self.width(), self.height())
246 }
247
248 pub fn read_bytes(
250 &self,
251 dst: &mut [u8],
252 data_layout: usize,
253 bytes_per_row: usize,
254 region: ImageRegion,
255 params: ImageReadWriteParams,
256 image_index: usize,
257 ) -> Result<()> {
258 let expected = required_bytes(data_layout, bytes_per_row, region, params);
259 if dst.len() < expected {
260 return Err(Error::InvalidLength {
261 expected,
262 actual: dst.len(),
263 });
264 }
265
266 let _ = unsafe {
268 ffi::mps_image_read_bytes(
269 self.ptr,
270 dst.as_mut_ptr().cast(),
271 data_layout,
272 bytes_per_row,
273 region.x,
274 region.y,
275 region.z,
276 region.width,
277 region.height,
278 region.depth,
279 params.feature_channel_offset,
280 params.feature_channel_count,
281 image_index,
282 )
283 };
284 Ok(())
285 }
286
287 pub fn write_bytes(
289 &self,
290 src: &[u8],
291 data_layout: usize,
292 bytes_per_row: usize,
293 region: ImageRegion,
294 params: ImageReadWriteParams,
295 image_index: usize,
296 ) -> Result<()> {
297 let expected = required_bytes(data_layout, bytes_per_row, region, params);
298 if src.len() < expected {
299 return Err(Error::InvalidLength {
300 expected,
301 actual: src.len(),
302 });
303 }
304
305 let _ = unsafe {
307 ffi::mps_image_write_bytes(
308 self.ptr,
309 src.as_ptr().cast(),
310 data_layout,
311 bytes_per_row,
312 region.x,
313 region.y,
314 region.z,
315 region.width,
316 region.height,
317 region.depth,
318 params.feature_channel_offset,
319 params.feature_channel_count,
320 image_index,
321 )
322 };
323 Ok(())
324 }
325
326 pub fn read_f32(&self) -> Result<Vec<f32>> {
328 let len = self.width() * self.height() * self.feature_channels();
329 let mut data = vec![0.0_f32; len];
330 let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
331 let bytes = unsafe {
333 core::slice::from_raw_parts_mut(
334 data.as_mut_ptr().cast::<u8>(),
335 core::mem::size_of_val(data.as_slice()),
336 )
337 };
338 self.read_bytes(
339 bytes,
340 image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
341 bytes_per_row,
342 self.whole_region(),
343 ImageReadWriteParams::all(self.feature_channels()),
344 0,
345 )?;
346 Ok(data)
347 }
348
349 pub fn write_f32(&self, data: &[f32]) -> Result<()> {
351 let expected = self.width() * self.height() * self.feature_channels();
352 if data.len() != expected {
353 return Err(Error::InvalidLength {
354 expected: expected * core::mem::size_of::<f32>(),
355 actual: core::mem::size_of_val(data),
356 });
357 }
358
359 let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
360 let bytes = unsafe {
362 core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
363 };
364 self.write_bytes(
365 bytes,
366 image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
367 bytes_per_row,
368 self.whole_region(),
369 ImageReadWriteParams::all(self.feature_channels()),
370 0,
371 )
372 }
373}
374
375fn required_bytes(
376 data_layout: usize,
377 bytes_per_row: usize,
378 region: ImageRegion,
379 params: ImageReadWriteParams,
380) -> usize {
381 let rows = region.height.saturating_mul(region.depth);
382 let base = bytes_per_row.saturating_mul(rows);
383 if data_layout == image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH {
384 base.saturating_mul(params.feature_channel_count.max(1))
385 } else {
386 base
387 }
388}