use crate::error::{Error, Result};
use crate::ffi;
use apple_metal::{storage_mode, texture_usage, MetalDevice, MetalTexture};
use core::ffi::c_void;
use core::ptr;
/// Per-channel storage formats accepted in `ImageDescriptor::channel_format`.
///
/// The values are handed to the FFI layer as-is. They appear to mirror Metal
/// Performance Shaders' `MPSImageFeatureChannelFormat` enumeration — confirm
/// against the native shim before adding new entries.
pub mod feature_channel_format {
    /// No channel format.
    pub const NONE: usize = 0;
    /// 8-bit unsigned normalized channels.
    pub const UNORM8: usize = 1;
    /// 16-bit unsigned normalized channels.
    pub const UNORM16: usize = 2;
    /// 16-bit floating-point channels.
    pub const FLOAT16: usize = 3;
    /// 32-bit floating-point channels.
    pub const FLOAT32: usize = 4;
}
/// Orderings of image data inside a caller-provided byte buffer.
///
/// Passed as the `data_layout` argument of `Image::read_bytes` and
/// `Image::write_bytes`. The values appear to mirror `MPSDataLayout` — confirm
/// against the native shim.
// The constant names keep a lowercase `x` between dimensions for readability,
// which the `non_upper_case_globals` lint would otherwise reject.
#[allow(non_upper_case_globals)]
pub mod image_layout {
    /// Feature channels vary fastest, then width, then height.
    pub const HEIGHTxWIDTHxFEATURE_CHANNELS: usize = 0;
    /// Width varies fastest, then height, then feature channels.
    pub const FEATURE_CHANNELSxHEIGHTxWIDTH: usize = 1;
}
/// Edge-handling modes for kernels that sample outside an image's bounds.
///
/// The values appear to mirror `MPSImageEdgeMode` — confirm against the native
/// shim.
pub mod image_edge_mode {
    /// Edge mode `0` (named "zero").
    pub const ZERO: usize = 0;
    /// Edge mode `1` (named "clamp").
    pub const CLAMP: usize = 1;
}
/// Bit flags controlling kernel behaviour; combine them with `|`.
///
/// Each flag occupies one bit. The set appears to mirror `MPSKernelOptions` —
/// confirm against the native shim.
pub mod kernel_options {
    /// No flags set.
    pub const NONE: u32 = 0;
    /// Bit 0: skip API validation.
    pub const SKIP_API_VALIDATION: u32 = 1 << 0;
    /// Bit 1: allow reduced precision.
    pub const ALLOW_REDUCED_PRECISION: u32 = 1 << 1;
    /// Bit 2: disable internal tiling.
    pub const DISABLE_INTERNAL_TILING: u32 = 1 << 2;
    /// Bit 3: insert debug groups.
    pub const INSERT_DEBUG_GROUPS: u32 = 1 << 3;
    /// Bit 4: verbose output.
    pub const VERBOSE: u32 = 1 << 4;
}
/// Plain-data description of an image to allocate with `Image::new`.
///
/// Every field is forwarded unchanged to the FFI constructor. The type is a
/// bag of integers, so it also derives equality and hashing; that lets
/// descriptors be compared in tests or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageDescriptor {
    /// Per-channel storage format; presumably one of the
    /// `feature_channel_format` constants — confirm against callers.
    pub channel_format: usize,
    /// Image width.
    pub width: usize,
    /// Image height.
    pub height: usize,
    /// Number of feature channels per pixel.
    pub feature_channels: usize,
    /// Number of images held by the object (`ImageDescriptor::new` uses 1).
    pub number_of_images: usize,
    /// Texture usage bit mask built from `apple_metal::texture_usage` flags.
    pub usage: usize,
    /// Storage mode taken from `apple_metal::storage_mode`.
    pub storage_mode: usize,
}
impl ImageDescriptor {
    /// Builds a descriptor for a single image of the given size and format.
    ///
    /// Fields without a parameter get fixed defaults: one image, shader
    /// read + write usage, and managed storage.
    #[must_use]
    pub const fn new(
        width: usize,
        height: usize,
        feature_channels: usize,
        channel_format: usize,
    ) -> Self {
        // Readable and writable from shaders by default.
        let usage = texture_usage::SHADER_READ | texture_usage::SHADER_WRITE;
        Self {
            channel_format,
            width,
            height,
            feature_channels,
            number_of_images: 1,
            usage,
            storage_mode: storage_mode::MANAGED,
        }
    }
}
/// A box-shaped region of an image: an origin plus an extent.
///
/// The six values are forwarded one by one to the FFI read/write calls.
/// Equality and hashing are derived so regions can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageRegion {
    /// Origin along the x axis.
    pub x: usize,
    /// Origin along the y axis.
    pub y: usize,
    /// Origin along the z axis.
    pub z: usize,
    /// Extent along the x axis.
    pub width: usize,
    /// Extent along the y axis.
    pub height: usize,
    /// Extent along the z axis.
    pub depth: usize,
}
impl ImageRegion {
    /// Creates a region from its origin `(x, y, z)` and extent
    /// `(width, height, depth)`.
    #[must_use]
    pub const fn new(
        x: usize,
        y: usize,
        z: usize,
        width: usize,
        height: usize,
        depth: usize,
    ) -> Self {
        Self { x, y, z, width, height, depth }
    }

    /// Creates a region that starts at the origin, spans `width` by `height`,
    /// and has a depth of one.
    #[must_use]
    pub const fn whole(width: usize, height: usize) -> Self {
        Self {
            x: 0,
            y: 0,
            z: 0,
            width,
            height,
            depth: 1,
        }
    }
}
/// Selects which feature channels a read or write touches.
///
/// Both values are forwarded unchanged to the FFI read/write calls. Equality
/// and hashing are derived so parameter sets can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ImageReadWriteParams {
    /// Index of the first feature channel to transfer.
    pub feature_channel_offset: usize,
    /// Number of feature channels to transfer.
    pub feature_channel_count: usize,
}
impl ImageReadWriteParams {
    /// Selects `feature_channel_count` channels starting at
    /// `feature_channel_offset`.
    #[must_use]
    pub const fn new(feature_channel_offset: usize, feature_channel_count: usize) -> Self {
        Self { feature_channel_offset, feature_channel_count }
    }

    /// Selects the first `feature_channels` channels, starting at offset zero.
    #[must_use]
    pub const fn all(feature_channels: usize) -> Self {
        Self {
            feature_channel_offset: 0,
            feature_channel_count: feature_channels,
        }
    }
}
/// Owning handle to a native image object reached through the FFI layer.
///
/// The wrapped pointer is released when the handle is dropped. `Debug` is
/// derived so the handle can appear in logs and in `Debug` output of types
/// that contain it; it prints the pointer address only.
#[derive(Debug)]
pub struct Image {
    /// Raw object pointer; null means "nothing to release".
    ptr: *mut c_void,
}
// `Image` holds a raw pointer, so it is neither `Send` nor `Sync` automatically;
// these impls opt in by hand.
// SAFETY: NOTE(review) — this assumes the native object may be used and released
// from any thread and tolerates concurrent calls through `&Image`. Nothing in this
// file enforces that, and `write_bytes` changes image contents through `&self`.
// TODO confirm against the FFI shim / framework documentation.
unsafe impl Send for Image {}
unsafe impl Sync for Image {}
impl Drop for Image {
    /// Releases the wrapped object, if there is one.
    fn drop(&mut self) {
        // A null handle owns nothing, so there is nothing to release.
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null and this handle is being destroyed, so
        // the release happens at most once per handle. Assumes the pointer was a
        // valid owned object when the handle was built — see `from_raw`.
        unsafe { ffi::mps_object_release(self.ptr) };
        // Clear the field so the released pointer cannot be reused by accident.
        self.ptr = ptr::null_mut();
    }
}
impl Image {
#[must_use]
pub fn new(device: &MetalDevice, descriptor: ImageDescriptor) -> Option<Self> {
let ptr = unsafe {
ffi::mps_image_new_with_descriptor(
device.as_ptr(),
descriptor.channel_format,
descriptor.width,
descriptor.height,
descriptor.feature_channels,
descriptor.number_of_images,
descriptor.usage,
descriptor.storage_mode,
)
};
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub fn from_texture(texture: &MetalTexture, feature_channels: usize) -> Option<Self> {
let ptr = unsafe { ffi::mps_image_new_with_texture(texture.as_ptr(), feature_channels) };
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub const fn as_ptr(&self) -> *mut c_void {
self.ptr
}
#[must_use]
pub(crate) const unsafe fn from_raw(ptr: *mut c_void) -> Self {
Self { ptr }
}
#[must_use]
pub fn width(&self) -> usize {
unsafe { ffi::mps_image_width(self.ptr) }
}
#[must_use]
pub fn height(&self) -> usize {
unsafe { ffi::mps_image_height(self.ptr) }
}
#[must_use]
pub fn feature_channels(&self) -> usize {
unsafe { ffi::mps_image_feature_channels(self.ptr) }
}
#[must_use]
pub fn number_of_images(&self) -> usize {
unsafe { ffi::mps_image_number_of_images(self.ptr) }
}
#[must_use]
pub fn pixel_size(&self) -> usize {
unsafe { ffi::mps_image_pixel_size(self.ptr) }
}
#[must_use]
pub fn pixel_format(&self) -> usize {
unsafe { ffi::mps_image_pixel_format(self.ptr) }
}
#[must_use]
pub fn whole_region(&self) -> ImageRegion {
ImageRegion::whole(self.width(), self.height())
}
pub fn read_bytes(
&self,
dst: &mut [u8],
data_layout: usize,
bytes_per_row: usize,
region: ImageRegion,
params: ImageReadWriteParams,
image_index: usize,
) -> Result<()> {
let expected = required_bytes(data_layout, bytes_per_row, region, params);
if dst.len() < expected {
return Err(Error::InvalidLength {
expected,
actual: dst.len(),
});
}
let _ = unsafe {
ffi::mps_image_read_bytes(
self.ptr,
dst.as_mut_ptr().cast(),
data_layout,
bytes_per_row,
region.x,
region.y,
region.z,
region.width,
region.height,
region.depth,
params.feature_channel_offset,
params.feature_channel_count,
image_index,
)
};
Ok(())
}
pub fn write_bytes(
&self,
src: &[u8],
data_layout: usize,
bytes_per_row: usize,
region: ImageRegion,
params: ImageReadWriteParams,
image_index: usize,
) -> Result<()> {
let expected = required_bytes(data_layout, bytes_per_row, region, params);
if src.len() < expected {
return Err(Error::InvalidLength {
expected,
actual: src.len(),
});
}
let _ = unsafe {
ffi::mps_image_write_bytes(
self.ptr,
src.as_ptr().cast(),
data_layout,
bytes_per_row,
region.x,
region.y,
region.z,
region.width,
region.height,
region.depth,
params.feature_channel_offset,
params.feature_channel_count,
image_index,
)
};
Ok(())
}
pub fn read_f32(&self) -> Result<Vec<f32>> {
let len = self.width() * self.height() * self.feature_channels();
let mut data = vec![0.0_f32; len];
let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
let bytes = unsafe {
core::slice::from_raw_parts_mut(
data.as_mut_ptr().cast::<u8>(),
core::mem::size_of_val(data.as_slice()),
)
};
self.read_bytes(
bytes,
image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
bytes_per_row,
self.whole_region(),
ImageReadWriteParams::all(self.feature_channels()),
0,
)?;
Ok(data)
}
pub fn write_f32(&self, data: &[f32]) -> Result<()> {
let expected = self.width() * self.height() * self.feature_channels();
if data.len() != expected {
return Err(Error::InvalidLength {
expected: expected * core::mem::size_of::<f32>(),
actual: core::mem::size_of_val(data),
});
}
let bytes_per_row = self.width() * self.feature_channels() * core::mem::size_of::<f32>();
let bytes = unsafe {
core::slice::from_raw_parts(data.as_ptr().cast::<u8>(), core::mem::size_of_val(data))
};
self.write_bytes(
bytes,
image_layout::HEIGHTxWIDTHxFEATURE_CHANNELS,
bytes_per_row,
self.whole_region(),
ImageReadWriteParams::all(self.feature_channels()),
0,
)
}
}
/// Computes the minimum buffer length, in bytes, that a read or write needs.
///
/// The size is `bytes_per_row * height * depth`. For the
/// `FEATURE_CHANNELSxHEIGHTxWIDTH` layout it is further multiplied by the
/// number of feature channels, with a count of zero treated as one. Any other
/// layout value gets no channel multiplier. All multiplications saturate at
/// `usize::MAX` instead of overflowing.
fn required_bytes(
    data_layout: usize,
    bytes_per_row: usize,
    region: ImageRegion,
    params: ImageReadWriteParams,
) -> usize {
    // Only the channel-major layout needs a separate plane per channel.
    let planes = match data_layout {
        image_layout::FEATURE_CHANNELSxHEIGHTxWIDTH => params.feature_channel_count.max(1),
        _ => 1,
    };
    bytes_per_row
        .saturating_mul(region.height.saturating_mul(region.depth))
        .saturating_mul(planes)
}