use onnx_extractor::DataType;
use crate::{instruction, tensor::TensorDesc, utils::error::VKMLError};
use super::{Layer, execution::LayerExecution};
/// Layer that joins two or more input tensors end-to-end along one axis.
pub struct ConcatLayer {
    /// Index of the axis along which inputs are joined.
    pub dim: usize,
}

impl ConcatLayer {
    /// Creates a concat layer that joins its inputs along axis `dim`.
    pub fn new(dim: usize) -> Self {
        ConcatLayer { dim }
    }
}
impl Layer for ConcatLayer {
    fn name(&self) -> String {
        "Concat".to_string()
    }

    fn config_string(&self) -> Option<String> {
        Some(format!("dim={}", self.dim))
    }

    /// Concat only rearranges its inputs; it owns no weights.
    fn requires_parameters(&self) -> bool {
        false
    }

    /// At least two inputs are required; `None` presumably means there is no
    /// upper bound on the number of inputs.
    fn input_requirements(&self) -> (usize, Option<usize>) {
        (2, None)
    }

    /// No parameter tensors, consistent with `requires_parameters` returning `false`.
    fn parameter_shapes(
        &self,
        _input_shapes: &[&TensorDesc],
    ) -> Option<(TensorDesc, TensorDesc)> {
        None
    }

    /// Computes the single output descriptor of the concatenation.
    ///
    /// The output dims equal the first input's dims, except along `self.dim`,
    /// where they are the sum of every input's size on that axis. The output
    /// dtype is always `DataType::Float`.
    ///
    /// # Errors
    /// Returns `VKMLError::Layer` when:
    /// - fewer than two inputs are given;
    /// - the inputs do not all have the same rank;
    /// - `self.dim` is not a valid axis for that rank;
    /// - any axis other than `self.dim` differs in size between inputs.
    fn output_shapes(
        &self,
        _batch_size: i64,
        input_shapes: &[&TensorDesc],
    ) -> Result<Vec<TensorDesc>, VKMLError> {
        if input_shapes.len() < 2 {
            return Err(VKMLError::Layer(format!(
                "Concat layer requires at least 2 inputs, got {}",
                input_shapes.len()
            )));
        }

        let first_dims = input_shapes[0].dims();
        let ndim = first_dims.len();

        // Ranks must agree before any per-axis indexing below is safe.
        for (i, shape) in input_shapes.iter().enumerate().skip(1) {
            let rank = shape.dims().len();
            if rank != ndim {
                return Err(VKMLError::Layer(format!(
                    "All inputs to Concat must have same number of dimensions \
                     (input 0 has {}, input {} has {})",
                    ndim, i, rank
                )));
            }
        }

        if self.dim >= ndim {
            return Err(VKMLError::Layer(format!(
                "Concat dimension {} out of range for {}-dimensional tensors",
                self.dim, ndim
            )));
        }

        // Every axis except the concat axis must match the first input exactly.
        for (d, &size) in first_dims.iter().enumerate() {
            if d == self.dim {
                continue;
            }
            for (i, shape) in input_shapes.iter().enumerate().skip(1) {
                let other = shape.dims()[d];
                if other != size {
                    return Err(VKMLError::Layer(format!(
                        "Dimension {} must have same size for all inputs to Concat \
                         (input 0 has {}, input {} has {})",
                        d, size, i, other
                    )));
                }
            }
        }

        // The non-concat axes are verified equal, so the first input's dims act
        // as the template; only the concat axis changes.
        let mut output_dims = first_dims.to_vec();
        output_dims[self.dim] = input_shapes
            .iter()
            .map(|shape| shape.dims()[self.dim])
            .sum();

        // NOTE(review): the output dtype is hard-coded to Float regardless of the
        // inputs' dtype; confirm whether non-float concat needs to be supported.
        Ok(vec![TensorDesc::new(output_dims, DataType::Float)])
    }

    /// Builds the execution plan for this layer.
    ///
    /// Entries `0..n` of `tensors` are copies of the `n` input descriptors, and
    /// entry `n` is the output descriptor. A single `concat` instruction reads
    /// the input entries and writes the output entry along `self.dim`.
    ///
    /// # Errors
    /// Propagates any validation error from `output_shapes`.
    fn build_layer_exec(
        &self,
        batch_size: i64,
        input_shapes: &[&TensorDesc],
    ) -> Result<LayerExecution, VKMLError> {
        // `output_shapes` performs all input validation, including the
        // minimum-input-count check, so no separate guard is needed here.
        let output_shape = self
            .output_shapes(batch_size, input_shapes)?
            .into_iter()
            .next()
            .ok_or_else(|| {
                VKMLError::Layer("Concat layer produced no output shape".to_string())
            })?;

        let input_count = input_shapes.len();
        let input_tensor_indices: Vec<usize> = (0..input_count).collect();

        let mut tensors = Vec::with_capacity(input_count + 1);
        tensors.extend(input_shapes.iter().copied().cloned());
        let output_idx = tensors.len();
        tensors.push(output_shape);

        let instruction = instruction::concat(input_tensor_indices, output_idx, self.dim);
        let input_mappings = self.map_input_tensors(input_count);

        Ok(LayerExecution {
            tensors,
            instructions: vec![instruction],
            outputs: vec![output_idx],
            input_mappings,
        })
    }
}