use super::*;
use crate::gpu::coreml::common::coreml_feature;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive};
/// Convolution entry points that run a tensor through the CoreML backend.
///
/// In every method `self` is the convolution input and `kernel` holds the
/// filter weights. The `Tensor` implementation in this file wraps the
/// arguments in a `ConvolutionOperation`, whose validation requires:
///
/// * input and kernel to be 4D (`[N, C, H, W]` and
///   `[out_channels, in_channels, kH, kW]` respectively),
/// * `stride` to have 2 elements `[height, width]`,
/// * `padding` to have 2 elements `[height, width]` or 4 elements
///   `[top, bottom, left, right]`.
pub trait CoreMLConvolution<T>
where
T: Float + FromPrimitive + ScalarOperand + 'static,
{
/// Standard 2D convolution of `self` with `kernel`.
fn coreml_conv2d(
&self,
kernel: &Self,
stride: &[usize],
padding: &[usize],
) -> CoreMLResult<Tensor<T>>;
/// Batched 2D convolution.
///
/// The `Tensor` implementation in this file forwards directly to
/// `coreml_conv2d`.
fn coreml_batch_conv2d(
&self,
kernel: &Self,
stride: &[usize],
padding: &[usize],
) -> CoreMLResult<Tensor<T>>;
/// Transposed 2D convolution.
///
/// NOTE(review): `ConvolutionOperation::execute_coreml` currently returns an
/// "unsupported operation" error for this convolution type.
fn coreml_conv2d_transpose(
&self,
kernel: &Self,
stride: &[usize],
padding: &[usize],
) -> CoreMLResult<Tensor<T>>;
/// Depthwise 2D convolution; validation expects the kernel to have exactly
/// one input channel.
///
/// NOTE(review): `ConvolutionOperation::execute_coreml` currently returns an
/// "unsupported operation" error for this convolution type.
fn coreml_depthwise_conv2d(
&self,
kernel: &Self,
stride: &[usize],
padding: &[usize],
) -> CoreMLResult<Tensor<T>>;
}
/// A single convolution request: the operands plus the hyper-parameters
/// needed to validate and execute it.
pub struct ConvolutionOperation<T: Float> {
/// Input tensor; validation requires 4 dimensions `[N, C, H, W]`.
input: Tensor<T>,
/// Filter weights; validation requires 4 dimensions
/// `[out_channels, in_channels, kH, kW]`.
kernel: Tensor<T>,
/// Stride as `[height, width]` (validation requires exactly 2 elements).
stride: Vec<usize>,
/// Padding as `[height, width]` or `[top, bottom, left, right]`.
padding: Vec<usize>,
/// Which convolution variant this operation represents.
conv_type: ConvolutionType,
}
/// Convolution variants a `ConvolutionOperation` can describe.
///
/// Only `Conv2D` is dispatched to the CoreML graph by
/// `ConvolutionOperation::execute_coreml`; the other variants produce an
/// "unsupported operation" error there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolutionType {
/// Standard 2D convolution; kernel input channels must equal input channels.
Conv2D,
/// Transposed 2D convolution (no variant-specific validation is performed).
TransposedConv2D,
/// Depthwise 2D convolution; the kernel must have exactly 1 input channel.
DepthwiseConv2D,
/// Grouped 2D convolution (no variant-specific validation is performed).
GroupedConv2D,
}
impl<T> ConvolutionOperation<T>
where
    T: Float + FromPrimitive + ScalarOperand + 'static,
{
    /// Bundles the operands and hyper-parameters of a convolution.
    ///
    /// Nothing is validated here; `validate_parameters` performs the checks
    /// before execution.
    pub fn new(
        input: Tensor<T>,
        kernel: Tensor<T>,
        stride: Vec<usize>,
        padding: Vec<usize>,
        conv_type: ConvolutionType,
    ) -> Self {
        Self {
            input,
            kernel,
            stride,
            padding,
            conv_type,
        }
    }

    /// Checks that shapes, stride and padding describe a well-formed
    /// convolution.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * the input or the kernel is not 4D,
    /// * (`Conv2D`) the kernel's input channels differ from the input's channels,
    /// * (`DepthwiseConv2D`) the kernel does not have exactly 1 input channel,
    /// * `stride` does not have 2 elements or contains a zero,
    /// * `padding` has neither 2 nor 4 elements.
    fn validate_parameters(&self) -> CoreMLResult<()> {
        let input_shape = self.input.shape();
        let kernel_shape = self.kernel.shape();

        if input_shape.len() != 4 {
            return Err(error_helpers::unsupported_operation(
                "Convolution input must be 4D tensor [N, C, H, W]",
            ));
        }
        if kernel_shape.len() != 4 {
            return Err(error_helpers::unsupported_operation(
                "Convolution kernel must be 4D tensor [out_channels, in_channels, kH, kW]",
            ));
        }

        let input_channels = input_shape[1];
        let kernel_input_channels = kernel_shape[1];

        // Every variant is listed explicitly so that adding a new
        // `ConvolutionType` forces a decision about its validation rules.
        match self.conv_type {
            ConvolutionType::Conv2D => {
                if input_channels != kernel_input_channels {
                    return Err(error_helpers::tensor_op_error(&format!(
                        "Channel mismatch: input has {} channels, kernel expects {}",
                        input_channels, kernel_input_channels
                    )));
                }
            }
            ConvolutionType::DepthwiseConv2D => {
                if kernel_input_channels != 1 {
                    return Err(error_helpers::unsupported_operation(
                        "Depthwise convolution kernel should have 1 input channel",
                    ));
                }
            }
            ConvolutionType::TransposedConv2D | ConvolutionType::GroupedConv2D => {}
        }

        if self.stride.len() != 2 {
            return Err(error_helpers::unsupported_operation(
                "Stride must have 2 elements [height, width]",
            ));
        }
        // A zero stride is meaningless and would make the output-shape
        // computation divide by zero.
        if self.stride.contains(&0) {
            return Err(error_helpers::unsupported_operation(
                "Stride elements must be greater than zero",
            ));
        }
        if self.padding.len() != 2 && self.padding.len() != 4 {
            return Err(error_helpers::unsupported_operation(
                "Padding must have 2 elements [height, width] or 4 elements [top, bottom, left, right]"
            ));
        }

        Ok(())
    }

    /// Heuristic deciding whether the problem size falls inside the window
    /// in which dispatching to CoreML is considered worthwhile.
    ///
    /// All four criteria must hold: spatial size (`H * W`) in
    /// `256..=1_048_576`, at least 4 input and 4 output channels, batch size
    /// in `1..=32`, and `in_channels * out_channels * kH * kW` in
    /// `1024..=16_777_216`.
    ///
    /// Returns `false` (instead of panicking on an out-of-range index) when
    /// either tensor is not 4D.
    fn is_efficient_on_coreml(&self) -> bool {
        let input_shape = self.input.shape();
        let kernel_shape = self.kernel.shape();
        if input_shape.len() != 4 || kernel_shape.len() != 4 {
            return false;
        }

        let batch_size = input_shape[0];
        let input_channels = input_shape[1];
        let output_channels = kernel_shape[0];

        let spatial_size = input_shape[2].saturating_mul(input_shape[3]);
        let total_params = input_channels
            .saturating_mul(output_channels)
            .saturating_mul(kernel_shape[2])
            .saturating_mul(kernel_shape[3]);

        let spatial_efficient = (256..=1_048_576).contains(&spatial_size);
        let channel_efficient = input_channels >= 4 && output_channels >= 4;
        let batch_efficient = (1..=32).contains(&batch_size);
        let param_efficient = (1024..=16_777_216).contains(&total_params);

        spatial_efficient && channel_efficient && batch_efficient && param_efficient
    }

    /// Number of output positions along one spatial axis:
    /// `(input + pad_before + pad_after - kernel) / stride + 1`.
    ///
    /// Returns 0 when the kernel does not fit inside the padded input or
    /// when `stride` is 0, rather than underflowing or dividing by zero.
    fn output_extent(
        input: usize,
        pad_before: usize,
        pad_after: usize,
        kernel: usize,
        stride: usize,
    ) -> usize {
        let padded = input.saturating_add(pad_before).saturating_add(pad_after);
        match padded.checked_sub(kernel) {
            Some(span) if stride > 0 => span / stride + 1,
            _ => 0,
        }
    }

    /// Computes the output shape `[N, out_channels, out_h, out_w]` using the
    /// standard (non-transposed) convolution formula.
    ///
    /// A spatial dimension is reported as 0 when the kernel is larger than
    /// the padded input along that axis.
    ///
    /// # Panics
    ///
    /// Indexes the shapes, stride and padding directly, so it panics if the
    /// tensors are not 4D, `stride` has fewer than 2 elements, or `padding`
    /// has a length other than 2 and fewer than 4 elements. Call
    /// `validate_parameters` first.
    fn calculate_output_shape(&self) -> Vec<usize> {
        let input_shape = self.input.shape();
        let kernel_shape = self.kernel.shape();

        let batch_size = input_shape[0];
        let output_channels = kernel_shape[0];

        // Two-element padding is symmetric per axis; otherwise it is read as
        // [top, bottom, left, right].
        let (pad_top, pad_bottom, pad_left, pad_right) = match self.padding.as_slice() {
            &[vertical, horizontal] => (vertical, vertical, horizontal, horizontal),
            explicit => (explicit[0], explicit[1], explicit[2], explicit[3]),
        };

        let output_height = Self::output_extent(
            input_shape[2],
            pad_top,
            pad_bottom,
            kernel_shape[2],
            self.stride[0],
        );
        let output_width = Self::output_extent(
            input_shape[3],
            pad_left,
            pad_right,
            kernel_shape[3],
            self.stride[1],
        );

        vec![batch_size, output_channels, output_height, output_width]
    }
}
impl<T> CoreMLOperation<T> for ConvolutionOperation<T>
where
    T: Float + FromPrimitive + ScalarOperand + 'static,
{
    /// Validates the operation and runs it on the CoreML graph for
    /// `device_id`.
    ///
    /// # Errors
    ///
    /// * Any error from `validate_parameters`.
    /// * With a CoreML feature enabled: errors from `CoreMLGraph::new` or
    ///   `conv2d`, and an "unsupported operation" error for every
    ///   convolution type other than `Conv2D`.
    /// * Without any CoreML feature: `error_helpers::feature_disabled()`.
    fn execute_coreml(&self, device_id: usize) -> CoreMLResult<Tensor<T>> {
        self.validate_parameters()?;

        #[cfg(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        ))]
        {
            use crate::gpu::coreml::backend::CoreMLGraph;

            let graph = CoreMLGraph::new(device_id)?;
            return match self.conv_type {
                ConvolutionType::Conv2D => {
                    graph.conv2d(&self.input, &self.kernel, &self.stride, &self.padding)
                }
                ConvolutionType::TransposedConv2D => Err(error_helpers::unsupported_operation(
                    "Transposed convolution not yet implemented in CoreML backend",
                )),
                ConvolutionType::DepthwiseConv2D => Err(error_helpers::unsupported_operation(
                    "Depthwise convolution not yet implemented in CoreML backend",
                )),
                ConvolutionType::GroupedConv2D => Err(error_helpers::unsupported_operation(
                    "Grouped convolution not yet implemented in CoreML backend",
                )),
            };
        }

        #[cfg(not(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        )))]
        {
            // The device is only consulted when a CoreML feature is compiled in.
            let _ = device_id;
            Err(error_helpers::feature_disabled())
        }
    }

    /// Returns `true` only for `Conv2D` operations whose parameters validate
    /// and whose size passes the efficiency heuristic.
    ///
    /// The checks short-circuit in that order, so the efficiency heuristic,
    /// which reads individual shape dimensions, is never evaluated on
    /// tensors that failed validation.
    fn is_supported_by_coreml(&self) -> bool {
        matches!(self.conv_type, ConvolutionType::Conv2D)
            && self.validate_parameters().is_ok()
            && self.is_efficient_on_coreml()
    }

    /// Rough cost estimate derived from the operation's FLOP count.
    ///
    /// Returns `None` when the operation is not supported by CoreML.
    /// Otherwise the estimate is `total_flops * 2` nanoseconds, where
    /// `total_flops = output_elements * (in_channels * kH * kW) * 2`.
    /// All arithmetic saturates at `u64::MAX` instead of overflowing.
    fn estimated_execution_time(&self) -> Option<std::time::Duration> {
        if !self.is_supported_by_coreml() {
            return None;
        }

        // `usize -> u64` is lossless on every platform with pointers of at
        // most 64 bits, so the `as` casts below cannot truncate.
        let saturating_product = |dims: &[usize]| -> u64 {
            dims.iter()
                .fold(1u64, |acc, &dim| acc.saturating_mul(dim as u64))
        };

        let kernel_shape = self.kernel.shape();
        let output_elements = saturating_product(&self.calculate_output_shape());
        let kernel_flops =
            saturating_product(&[kernel_shape[2], kernel_shape[3], kernel_shape[1]]);

        // Factor 2: presumably one multiply and one add per kernel element.
        let total_flops = output_elements
            .saturating_mul(kernel_flops)
            .saturating_mul(2);

        Some(std::time::Duration::from_nanos(
            total_flops.saturating_mul(2),
        ))
    }
}
impl<T> CoreMLConvolution<T> for Tensor<T>
where
    T: Float + FromPrimitive + ScalarOperand + 'static,
{
    /// Runs a standard 2D convolution of `self` with `kernel` through a
    /// `CoreMLExecutor` created for device 0.
    fn coreml_conv2d(
        &self,
        kernel: &Self,
        stride: &[usize],
        padding: &[usize],
    ) -> CoreMLResult<Tensor<T>> {
        // The operation owns its operands, hence the clones.
        let conv = ConvolutionOperation::new(
            self.clone(),
            kernel.clone(),
            stride.to_vec(),
            padding.to_vec(),
            ConvolutionType::Conv2D,
        );
        CoreMLExecutor::new(0)?.execute(&conv)
    }

    /// Batched 2D convolution; identical to `coreml_conv2d`, to which it
    /// forwards all arguments unchanged.
    fn coreml_batch_conv2d(
        &self,
        kernel: &Self,
        stride: &[usize],
        padding: &[usize],
    ) -> CoreMLResult<Tensor<T>> {
        self.coreml_conv2d(kernel, stride, padding)
    }

    /// Runs a transposed 2D convolution of `self` with `kernel` through a
    /// `CoreMLExecutor` created for device 0.
    fn coreml_conv2d_transpose(
        &self,
        kernel: &Self,
        stride: &[usize],
        padding: &[usize],
    ) -> CoreMLResult<Tensor<T>> {
        let transposed = ConvolutionOperation::new(
            self.clone(),
            kernel.clone(),
            stride.to_vec(),
            padding.to_vec(),
            ConvolutionType::TransposedConv2D,
        );
        CoreMLExecutor::new(0)?.execute(&transposed)
    }

    /// Runs a depthwise 2D convolution of `self` with `kernel` through a
    /// `CoreMLExecutor` created for device 0.
    fn coreml_depthwise_conv2d(
        &self,
        kernel: &Self,
        stride: &[usize],
        padding: &[usize],
    ) -> CoreMLResult<Tensor<T>> {
        let depthwise = ConvolutionOperation::new(
            self.clone(),
            kernel.clone(),
            stride.to_vec(),
            padding.to_vec(),
            ConvolutionType::DepthwiseConv2D,
        );
        CoreMLExecutor::new(0)?.execute(&depthwise)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed, reasonably sized Conv2D passes validation and is
    /// reported as supported.
    #[test]
    fn test_conv2d_parameter_validation() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[1, 8, 32, 32]),
            Tensor::<f32>::zeros(&[16, 8, 3, 3]),
            vec![1, 1],
            vec![1, 1],
            ConvolutionType::Conv2D,
        );

        assert!(op.validate_parameters().is_ok());
        assert!(op.is_supported_by_coreml());
    }

    /// Input channels (3) that differ from the kernel's input channels (4)
    /// are rejected.
    #[test]
    fn test_conv2d_channel_mismatch() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[1, 3, 32, 32]),
            Tensor::<f32>::zeros(&[16, 4, 3, 3]),
            vec![1, 1],
            vec![1, 1],
            ConvolutionType::Conv2D,
        );

        assert!(op.validate_parameters().is_err());
        assert!(!op.is_supported_by_coreml());
    }

    /// A tiny single-channel convolution is valid but falls outside the
    /// efficiency window, so it is not supported.
    #[test]
    fn test_small_convolution_not_efficient() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[1, 1, 8, 8]),
            Tensor::<f32>::zeros(&[1, 1, 3, 3]),
            vec![1, 1],
            vec![1, 1],
            ConvolutionType::Conv2D,
        );

        assert!(op.validate_parameters().is_ok());
        assert!(!op.is_efficient_on_coreml());
        assert!(!op.is_supported_by_coreml());
    }

    /// 32x32 input, 3x3 kernel, stride 2, padding 1 yields a 16x16 output.
    #[test]
    fn test_output_shape_calculation() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[2, 3, 32, 32]),
            Tensor::<f32>::zeros(&[16, 3, 3, 3]),
            vec![2, 2],
            vec![1, 1],
            ConvolutionType::Conv2D,
        );

        let shape = op.calculate_output_shape();
        assert_eq!(shape, vec![2, 16, 16, 16]);
    }

    /// A supported convolution yields a strictly positive time estimate.
    #[test]
    fn test_execution_time_estimation() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[1, 16, 64, 64]),
            Tensor::<f32>::zeros(&[32, 16, 3, 3]),
            vec![1, 1],
            vec![1, 1],
            ConvolutionType::Conv2D,
        );

        let estimate = op.estimated_execution_time();
        assert!(estimate.is_some());

        let elapsed = estimate.unwrap();
        assert!(elapsed.as_nanos() > 0);
    }

    /// A depthwise kernel with exactly one input channel validates.
    #[test]
    fn test_depthwise_validation() {
        let op = ConvolutionOperation::new(
            Tensor::<f32>::zeros(&[1, 8, 32, 32]),
            Tensor::<f32>::zeros(&[8, 1, 3, 3]),
            vec![1, 1],
            vec![1, 1],
            ConvolutionType::DepthwiseConv2D,
        );

        assert!(op.validate_parameters().is_ok());
    }
}