use onnx_extractor::DataType;
use crate::utils::onnx_autopad::calc_begin_and_end_pads;
use crate::{
instruction,
tensor::TensorDesc,
utils::{OnnxAutoPad, error::VKMLError},
};
use super::{Layer, execution::LayerExecution};
#[derive(Clone)]
pub struct ConvLayer {
pub in_features: i64, pub out_features: i64, pub auto_pad: OnnxAutoPad,
pub dilations: Vec<i64>,
pub kernel_shape: Vec<i64>,
pub pads: Vec<i64>,
pub strides: Vec<i64>,
pub bias: bool,
}
impl ConvLayer {
pub fn new_with(
in_features: i64,
out_features: i64,
auto_pad: OnnxAutoPad,
dilations: Vec<i64>,
kernel_shape: Vec<i64>,
pads: Vec<i64>,
strides: Vec<i64>,
bias: bool,
) -> Self {
Self {
in_features,
out_features,
auto_pad,
dilations,
kernel_shape,
pads,
strides,
bias,
}
}
}
impl Layer for ConvLayer {
fn output_shapes(
&self,
batch_size: i64,
input_shapes: &[&TensorDesc],
) -> Result<Vec<TensorDesc>, VKMLError> {
if input_shapes.len() != 1 {
return Err(VKMLError::Layer(format!(
"Conv layer requires exactly 1 input, got {}",
input_shapes.len()
)));
}
let input_shape = input_shapes[0];
let ndim = input_shape.ndim();
if ndim < 3 {
return Err(VKMLError::Layer(format!(
"Conv requires input tensor with at least 3 dims (N,C,spatial...), got {:?}",
input_shape
)));
}
let spatial_rank = ndim - 2;
let input_dims = input_shape.dims();
let in_channels = input_dims[1];
if input_dims[1] != self.in_features {
return Err(VKMLError::Layer(format!(
"Conv expected {} input channels, got {}",
self.in_features, in_channels
)));
}
let mut spatial_input: Vec<i64> = Vec::with_capacity(spatial_rank);
for i in 0..spatial_rank {
spatial_input.push(input_dims[2 + i]);
}
let (pads_begin, pads_end) = calc_begin_and_end_pads(
self.auto_pad.clone(),
&self.pads,
&self.kernel_shape,
&self.strides,
&self.dilations,
input_shape,
);
let mut out_spatial: Vec<i64> = Vec::with_capacity(spatial_rank);
for i in 0..spatial_rank {
let in_i = spatial_input[i];
let k = self.kernel_shape.get(i).copied().unwrap_or(1);
let s = self.strides.get(i).copied().unwrap_or(1);
let d = self.dilations.get(i).copied().unwrap_or(1);
let p_begin = pads_begin[i];
let p_end = pads_end[i];
let numerator = in_i + p_begin + p_end - d * (k - 1) - 1;
let out_i = (numerator / s) + 1;
out_spatial.push(out_i);
}
let mut out_dims = vec![batch_size, self.out_features];
out_dims.extend(out_spatial.iter());
Ok(vec![TensorDesc::new(out_dims, DataType::Float)])
}
fn parameter_shapes(&self, _input_shapes: &[&TensorDesc]) -> Option<(TensorDesc, TensorDesc)> {
let mut w_dims: Vec<i64> = Vec::with_capacity(2 + self.kernel_shape.len());
w_dims.push(self.out_features);
w_dims.push(self.in_features); for &k in &self.kernel_shape {
w_dims.push(k);
}
let weights = TensorDesc::new(w_dims, DataType::Float);
let biases = TensorDesc::new(vec![self.out_features], DataType::Float);
Some((weights, biases))
}
fn parameter_count(&self, _batch_size: i64, _input_shapes: &[&TensorDesc]) -> i64 {
let mut kernel_prod: i64 = 1;
if !self.kernel_shape.is_empty() {
for &k in &self.kernel_shape {
kernel_prod *= k;
}
}
let weight_params = self.out_features * self.in_features * kernel_prod;
let bias_params = if self.bias { self.out_features } else { 0 };
weight_params + bias_params
}
fn input_requirements(&self) -> (usize, Option<usize>) {
(1, Some(1))
}
fn name(&self) -> String {
"Conv".to_string()
}
fn config_string(&self) -> Option<String> {
let ks = if self.kernel_shape.is_empty() {
"[1]".to_string()
} else {
format!(
"[{}]",
self.kernel_shape
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(",")
)
};
let ss = if self.strides.is_empty() {
"[1]".to_string()
} else {
format!(
"[{}]",
self.strides
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(",")
)
};
let ps = if self.pads.is_empty() {
"[]".to_string()
} else {
format!(
"[{}]",
self.pads
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(",")
)
};
let ds = if self.dilations.is_empty() {
"[1]".to_string()
} else {
format!(
"[{}]",
self.dilations
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(",")
)
};
Some(format!(
"auto_pad={:?}, kernel={}, stride={}, dilation={}, pads={}, bias={}",
self.auto_pad, ks, ss, ds, ps, self.bias
))
}
fn in_features(&self) -> i64 {
self.in_features
}
fn out_features(&self) -> i64 {
self.out_features
}
fn build_layer_exec(
&self,
batch_size: i64,
input_shapes: &[&TensorDesc],
) -> Result<LayerExecution, VKMLError> {
if input_shapes.is_empty() {
return Err(VKMLError::Layer("Conv layer requires an input".to_string()));
}
let input_shape = input_shapes[0];
if input_shape.ndim() < 3 {
return Err(VKMLError::Layer(
"Conv layer expects at least 3D tensor input".into(),
));
}
let input_dims = input_shape.dims();
let in_channels = input_dims[1];
if in_channels != self.in_features {
return Err(VKMLError::Layer(format!(
"Conv layer expects {} input channels, got {}",
self.in_features, in_channels
)));
}
let spatial_rank = input_shape.ndim() - 2;
let mut spatial_input: Vec<i64> = Vec::with_capacity(spatial_rank);
for i in 0..spatial_rank {
spatial_input.push(input_dims[2 + i]);
}
let mut out_spatial: Vec<i64> = Vec::with_capacity(spatial_rank);
let mut kernel: Vec<i64> = Vec::with_capacity(spatial_rank);
let mut stride: Vec<i64> = Vec::with_capacity(spatial_rank);
let mut dilation: Vec<i64> = Vec::with_capacity(spatial_rank);
for i in 0..spatial_rank {
kernel.push(self.kernel_shape.get(i).copied().unwrap_or(1));
stride.push(self.strides.get(i).copied().unwrap_or(1));
dilation.push(self.dilations.get(i).copied().unwrap_or(1));
}
let (pads_begin, pads_end) = calc_begin_and_end_pads(
self.auto_pad.clone(),
&self.pads,
&self.kernel_shape,
&self.strides,
&self.dilations,
input_shape,
);
for i in 0..spatial_rank {
let in_i = spatial_input[i];
let k = kernel[i];
let s = stride[i];
let d = dilation[i];
let p_begin = pads_begin[i];
let p_end = pads_end[i];
let numerator = in_i + p_begin + p_end - d * (k - 1) - 1;
let out_i = (numerator / s) + 1;
out_spatial.push(out_i);
}
let mut tensors = Vec::new();
tensors.push(input_shape.clone());
let mut w_dims: Vec<i64> = Vec::with_capacity(2 + self.kernel_shape.len());
w_dims.push(self.out_features);
w_dims.push(self.in_features);
for &k in &self.kernel_shape {
w_dims.push(k);
}
tensors.push(TensorDesc::new(w_dims, DataType::Float));
let mut out_dims: Vec<i64> = Vec::with_capacity(2 + out_spatial.len());
out_dims.push(batch_size);
out_dims.push(self.out_features);
out_dims.extend(out_spatial.iter());
tensors.push(TensorDesc::new(out_dims, DataType::Float));
let mut bias_idx = None;
if self.bias {
bias_idx = Some(tensors.len());
tensors.push(TensorDesc::new(vec![self.out_features], DataType::Float));
}
let instruction = instruction::conv(
0, 1, bias_idx, 2, self.auto_pad.clone(), self.dilations.clone(), 1, self.kernel_shape.clone(), self.pads.clone(), self.strides.clone(), );
let input_mappings = self.map_input_tensors(input_shapes.len());
Ok(LayerExecution {
tensors,
instructions: vec![instruction],
outputs: vec![2],
input_mappings,
})
}
}