use metal::foreign_types::ForeignType;
use metal::NSUInteger;
use metal::{CommandBuffer, Device};
use objc2::msg_send;
use objc2::runtime::AnyObject;
use crate::metal::{
buffer::MetalBuffer,
error::{MetalError, Result},
mps::{create_conv_descriptor, create_image_descriptor, MPSDataType},
};
#[allow(dead_code)]
pub struct MPSConv2d {
conv: *mut AnyObject,
output: MetalBuffer,
params: Conv2dParams,
device: Device,
}
#[derive(Debug, Clone)]
pub struct Conv2dParams {
pub batch_size: usize,
pub in_channels: usize,
pub out_channels: usize,
pub height: usize,
pub width: usize,
pub kernel_height: usize,
pub kernel_width: usize,
pub stride_height: usize,
pub stride_width: usize,
pub padding_height: usize,
pub padding_width: usize,
pub dilation_height: usize,
pub dilation_width: usize,
pub groups: usize,
}
impl MPSConv2d {
pub fn new(
device: &Device,
params: Conv2dParams,
weights: &MetalBuffer,
bias: Option<&MetalBuffer>,
) -> Result<Self> {
unsafe {
if params.in_channels % params.groups != 0 || params.out_channels % params.groups != 0 {
return Err(MetalError::BackendError(
"Invalid group convolution parameters".to_string(),
));
}
let out_height = (params.height + 2 * params.padding_height - params.kernel_height)
/ params.stride_height
+ 1;
let out_width = (params.width + 2 * params.padding_width - params.kernel_width)
/ params.stride_width
+ 1;
let output_shape = vec![
params.batch_size,
params.out_channels,
out_height,
out_width,
];
let output = MetalBuffer::zeros(
&torsh_core::Shape::from(output_shape),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?;
let desc = create_conv_descriptor(
params.kernel_height,
params.kernel_width,
params.in_channels / params.groups,
params.out_channels,
);
let _: () = msg_send![desc, setStrideInPixelsX: params.stride_width as NSUInteger];
let _: () = msg_send![desc, setStrideInPixelsY: params.stride_height as NSUInteger];
let _: () = msg_send![desc, setGroups: params.groups as NSUInteger];
let class = objc2::class!(MPSCNNConvolution);
let conv: *mut AnyObject = msg_send![class, alloc];
let weights_ptr = weights.buffer().contents() as *const f32;
let bias_ptr = if let Some(b) = bias {
b.buffer().contents() as *const f32
} else {
std::ptr::null()
};
let conv: *mut AnyObject = msg_send![conv,
initWithDevice: device.as_ptr() as *mut AnyObject,
convolutionDescriptor: desc,
kernelWeights: weights_ptr,
biasTerms: bias_ptr,
flags: 0 as NSUInteger ];
let _: () = msg_send![conv, setPaddingLeft: params.padding_width as NSUInteger];
let _: () = msg_send![conv, setPaddingRight: params.padding_width as NSUInteger];
let _: () = msg_send![conv, setPaddingTop: params.padding_height as NSUInteger];
let _: () = msg_send![conv, setPaddingBottom: params.padding_height as NSUInteger];
Ok(Self {
conv,
output,
params: params.clone(),
device: device.clone(),
})
}
}
pub fn encode_conv(&self, command_buffer: &CommandBuffer, _input: &MetalBuffer) -> Result<()> {
unsafe {
let class = objc2::class!(MPSImage);
let in_desc = create_image_descriptor(
self.params.width,
self.params.height,
self.params.in_channels,
MPSDataType::Float32,
);
let input_image: *mut AnyObject = msg_send![class, alloc];
let input_image: *mut AnyObject = msg_send![input_image,
initWithDevice: self.device.as_ptr() as *mut AnyObject,
imageDescriptor: in_desc
];
let out_desc = create_image_descriptor(
(self.params.width + 2 * self.params.padding_width - self.params.kernel_width)
/ self.params.stride_width
+ 1,
(self.params.height + 2 * self.params.padding_height - self.params.kernel_height)
/ self.params.stride_height
+ 1,
self.params.out_channels,
MPSDataType::Float32,
);
let output_image: *mut AnyObject = msg_send![class, alloc];
let output_image: *mut AnyObject = msg_send![output_image,
initWithDevice: self.device.as_ptr() as *mut AnyObject,
imageDescriptor: out_desc
];
let _: () = msg_send![self.conv,
encodeToCommandBuffer: command_buffer.as_ptr() as *mut AnyObject,
sourceImage: input_image,
destinationImage: output_image
];
Ok(())
}
}
pub fn output(&self) -> &MetalBuffer {
&self.output
}
}
impl Drop for MPSConv2d {
fn drop(&mut self) {
unsafe {
let _: () = msg_send![self.conv, release];
}
}
}