use crate::layout::{convert_layout, DataLayout, LayoutOptimizer, OperationType};
use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::{One, Zero};
use std::collections::HashMap;
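
/// Performs a 2D convolution on `input`, converting between data layouts when
/// the layout optimizer estimates the conversion will pay for itself.
///
/// If no `optimizer` is supplied, a default `LayoutOptimizer` is used. The
/// convolution kernel itself runs in NCHW; NHWC inputs are converted before
/// the convolution and the result is converted back.
///
/// # Example
///
/// A minimal sketch; `Tensor::zeros` is a hypothetical constructor used only
/// for illustration, not this crate's actual API:
///
/// ```ignore
/// // Hypothetical tensor construction; substitute the crate's real API.
/// let input = Tensor::<f32>::zeros(&[1, 3, 224, 224]);
/// let weight = Tensor::<f32>::zeros(&[8, 3, 3, 3]);
/// let output = conv2d_with_layout(
///     &input, &weight, None, (1, 1), "same", DataLayout::NCHW, None,
/// )?;
/// ```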
pub fn conv2d_with_layout<T>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize),
    padding: &str,
    input_layout: DataLayout,
    optimizer: Option<&LayoutOptimizer>,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let default_optimizer = LayoutOptimizer::default();
    let layout_opt = optimizer.unwrap_or(&default_optimizer);
    let device = input.device();
    let preferred_layout = layout_opt.preferred_layout(device, OperationType::Convolution);

    // Convert to the preferred layout only when the optimizer estimates that
    // the conversion cost is outweighed by the operation's arithmetic intensity.
    let (working_input, actual_layout) = if input_layout != preferred_layout {
        let operation_intensity = 3.0; // heuristic intensity for convolution
        if layout_opt.should_convert(input_layout, preferred_layout, operation_intensity) {
            let converted = convert_layout(input, input_layout, preferred_layout)?;
            (converted, preferred_layout)
        } else {
            (input.clone(), input_layout)
        }
    } else {
        (input.clone(), input_layout)
    };
    // The convolution kernel itself operates on NCHW; NHWC inputs are
    // converted in, and the result is converted back out.
    let result = match actual_layout {
        DataLayout::NCHW => super::conv2d::conv2d(&working_input, weight, bias, stride, padding),
        DataLayout::NHWC => {
            let nchw_input = convert_layout(&working_input, DataLayout::NHWC, DataLayout::NCHW)?;
            let nchw_result = super::conv2d::conv2d(&nchw_input, weight, bias, stride, padding)?;
            convert_layout(&nchw_result, DataLayout::NCHW, DataLayout::NHWC)
        }
        _ => {
            return Err(TensorError::unsupported_operation_simple(format!(
                "Convolution not supported for layout {actual_layout:?}"
            )));
        }
    }?;
    Ok(result)
}
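
/// Performs a 2D convolution, inferring the input's data layout from its
/// shape via `crate::layout::infer_layout` and dispatching to
/// [`conv2d_with_layout`] with the default optimizer.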
pub fn conv2d_auto_layout<T>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize),
    padding: &str,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let inferred_layout = crate::layout::infer_layout(input.shape().dims(), Some(4));
    conv2d_with_layout(input, weight, bias, stride, padding, inferred_layout, None)
}
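
/// Estimates relative convolution performance for each candidate data layout
/// on `device`.
///
/// The returned scores are heuristic estimates (base cost * layout efficiency
/// * memory efficiency), not measured timings; the shape, stride, and padding
/// parameters are currently unused.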
pub fn conv_layout_benchmark<T>(
    _input_shape: &[usize],
    _weight_shape: &[usize],
    _stride: (usize, usize),
    _padding: &str,
    device: &crate::Device,
) -> HashMap<DataLayout, f32>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static,
{
    let mut benchmarks = HashMap::new();
    for layout in &[DataLayout::NCHW, DataLayout::NHWC] {
        let base_cost = 1.0;
        // Heuristic: GPU kernels typically favor NCHW, while CPU
        // implementations tend to favor the channels-last NHWC layout.
        let layout_efficiency = match (device, layout) {
            #[cfg(feature = "gpu")]
            (crate::Device::Gpu(_), DataLayout::NCHW) => 1.0,
            #[cfg(feature = "gpu")]
            (crate::Device::Gpu(_), DataLayout::NHWC) => 0.7,
            (crate::Device::Cpu, DataLayout::NHWC) => 1.0,
            (crate::Device::Cpu, DataLayout::NCHW) => 0.8,
            _ => 0.5,
        };
        let memory_efficiency = match layout {
            DataLayout::NCHW => 0.9,
            DataLayout::NHWC => 1.0,
            _ => 0.7,
        };
        let estimated_performance = base_cost * layout_efficiency * memory_efficiency;
        benchmarks.insert(*layout, estimated_performance);
    }
    benchmarks
}
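
/// Selects the data layout with the highest estimated benchmark score for a
/// convolution on `device`.
///
/// Note that `operation_intensity` adds the same bonus to every layout's
/// score, so the selection is driven entirely by the per-layout benchmark
/// estimates.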
pub fn select_optimal_layout(
    input_shape: &[usize],
    kernel_shape: &[usize],
    device: &crate::Device,
    operation_intensity: f32,
) -> DataLayout {
    let benchmarks =
        conv_layout_benchmark::<f32>(input_shape, kernel_shape, (1, 1), "same", device);
    let mut best_layout = DataLayout::NCHW;
    let mut best_score = 0.0;
    for (&layout, &performance) in &benchmarks {
        // The intensity bonus is uniform across layouts, so it does not change
        // which layout wins; it only inflates the absolute score.
        let intensity_bonus = operation_intensity * 0.1;
        let adjusted_score = performance + intensity_bonus;
        if adjusted_score > best_score {
            best_score = adjusted_score;
            best_layout = layout;
        }
    }
    best_layout
}
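
/// Estimates the cost of converting a tensor of `tensor_size` elements from
/// `from_layout` to `to_layout`.
///
/// The model is linear in tensor size: `tensor_size * 0.001`, scaled by a
/// complexity multiplier (1.0 for the common NCHW <-> NHWC pair, 1.5
/// otherwise). Converting 1000 elements between NCHW and NHWC therefore costs
/// 1.0; identical layouts cost nothing.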
pub fn layout_conversion_cost(
    from_layout: DataLayout,
    to_layout: DataLayout,
    tensor_size: usize,
) -> f32 {
    if from_layout == to_layout {
        return 0.0;
    }
    let base_cost = tensor_size as f32 * 0.001;
    let complexity_multiplier = match (from_layout, to_layout) {
        (DataLayout::NCHW, DataLayout::NHWC) | (DataLayout::NHWC, DataLayout::NCHW) => 1.0,
        _ => 1.5,
    };
    base_cost * complexity_multiplier
}
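
/// Decides whether converting `tensor_size` elements from `current_layout` to
/// `optimal_layout` is worthwhile over `operation_count` operations, given the
/// expected `performance_gain` factor in the optimal layout.
///
/// With per-operation cost `c = tensor_size * 0.01`, conversion pays off when
/// the savings `c * operation_count * (1 - 1 / performance_gain)` exceed the
/// one-time conversion cost from [`layout_conversion_cost`].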
pub fn should_convert_layout(
    current_layout: DataLayout,
    optimal_layout: DataLayout,
    tensor_size: usize,
    operation_count: usize,
    performance_gain: f32,
) -> bool {
    if current_layout == optimal_layout {
        return false;
    }
    let conversion_cost = layout_conversion_cost(current_layout, optimal_layout, tensor_size);
    let operation_cost = tensor_size as f32 * 0.01;
    let current_total_cost = operation_cost * operation_count as f32;
    let optimal_total_cost = (operation_cost / performance_gain) * operation_count as f32;
    let savings = current_total_cost - optimal_total_cost;
    savings > conversion_cost
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layout_selection() {
        let input_shape = [1, 64, 224, 224];
        let kernel_shape = [64, 64, 3, 3];
        #[cfg(feature = "gpu")]
        {
            // The GPU efficiency constants favor NCHW (0.9 vs 0.7).
            let gpu_device = crate::Device::Gpu(0);
            let layout = select_optimal_layout(&input_shape, &kernel_shape, &gpu_device, 2.0);
            assert_eq!(layout, DataLayout::NCHW);
        }
        // The CPU efficiency constants favor NHWC (1.0 vs 0.72).
        let cpu_device = crate::Device::Cpu;
        let layout = select_optimal_layout(&input_shape, &kernel_shape, &cpu_device, 2.0);
        assert_eq!(layout, DataLayout::NHWC);
    }
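
    // A minimal check of the benchmark heuristic itself: the expected CPU
    // ordering (NHWC 1.0 vs NCHW 0.72) follows directly from the constants in
    // `conv_layout_benchmark`, not from any real measurement.
    #[test]
    fn test_cpu_benchmark_ordering() {
        let benchmarks = conv_layout_benchmark::<f32>(
            &[1, 3, 8, 8],
            &[4, 3, 3, 3],
            (1, 1),
            "same",
            &crate::Device::Cpu,
        );
        assert_eq!(benchmarks.len(), 2);
        assert!(benchmarks[&DataLayout::NHWC] > benchmarks[&DataLayout::NCHW]);
    }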
    #[test]
    fn test_conversion_cost() {
        let tensor_size = 1000;
        let cost = layout_conversion_cost(DataLayout::NCHW, DataLayout::NCHW, tensor_size);
        assert_eq!(cost, 0.0);
        let cost = layout_conversion_cost(DataLayout::NCHW, DataLayout::NHWC, tensor_size);
        assert!(cost > 0.0);
    }
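
    // Worked example of the linear cost model: 1000 elements at 0.001 per
    // element with the 1.0 multiplier for NCHW <-> NHWC gives a cost of 1.0.
    #[test]
    fn test_conversion_cost_value() {
        let cost = layout_conversion_cost(DataLayout::NCHW, DataLayout::NHWC, 1000);
        assert!((cost - 1.0).abs() < 1e-6);
    }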
    #[test]
    fn test_should_convert_logic() {
        let tensor_size = 10000;
        let operation_count = 100;
        let high_performance_gain = 2.0;
        let low_performance_gain = 1.1;

        // Many operations at a large gain: the savings (5000.0) dwarf the
        // conversion cost (10.0), so conversion is worthwhile.
        let should_convert = should_convert_layout(
            DataLayout::NHWC,
            DataLayout::NCHW,
            tensor_size,
            operation_count,
            high_performance_gain,
        );
        assert!(should_convert);

        // A single operation at a small gain: the savings (~9.09) fall short
        // of the conversion cost (10.0), so conversion is not worthwhile.
        let should_convert = should_convert_layout(
            DataLayout::NHWC,
            DataLayout::NCHW,
            tensor_size,
            1,
            low_performance_gain,
        );
        assert!(!should_convert);
    }
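
    // The break-even point follows from the cost model: savings are
    // operation_cost * n * (1 - 1 / gain), so at a gain of 1.1 one operation
    // saves ~9.09 (below the 10.0 conversion cost) while two operations save
    // ~18.18 (above it).
    #[test]
    fn test_should_convert_break_even() {
        let tensor_size = 10000;
        assert!(!should_convert_layout(
            DataLayout::NHWC,
            DataLayout::NCHW,
            tensor_size,
            1,
            1.1,
        ));
        assert!(should_convert_layout(
            DataLayout::NHWC,
            DataLayout::NCHW,
            tensor_size,
            2,
            1.1,
        ));
    }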
}