mod f32_cpu;
use crate::{
ComputeManager, VKMLError,
gpu::vk_gpu::Gpu,
instruction::{Instruction, concat::f32_cpu::f32_cpu},
tensor_graph::TensorId,
};
use onnx_extractor::DataType;
use std::fmt::{Debug, Formatter, Result as FmtResult};
use vulkanalia::vk;
pub struct ConcatInstruction {
pub sources: Vec<TensorId>,
pub dst: TensorId,
pub dim: usize,
}
impl Debug for ConcatInstruction {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(
f,
"Concat(sources={:?}, dst={}, dim={})",
self.sources, self.dst, self.dim
)
}
}
impl Instruction for ConcatInstruction {
fn get_input_tensor_ids(&self) -> Vec<TensorId> {
self.sources.clone()
}
fn get_output_tensor_ids(&self) -> Vec<TensorId> {
vec![self.dst]
}
fn remap_tensor_ids(&mut self, new_inputs: &[TensorId], new_outputs: &[TensorId]) {
if !new_inputs.is_empty() {
self.sources = new_inputs.to_vec();
}
if !new_outputs.is_empty() {
self.dst = new_outputs[0];
}
}
fn record_into_command_buffer(
&self,
_gpu: &Gpu,
_command_buffer: vk::CommandBuffer,
_cm: &ComputeManager,
) -> Result<(), VKMLError> {
Err(VKMLError::Instruction(
"GPU implementation of Concat not yet supported".to_string(),
))
}
fn execute_cpu(&self, cm: &ComputeManager) {
assert_eq!(
self.dim, 1,
"CPU Concat only implemented for dimension 1, got {}",
self.dim
);
assert!(
!self.sources.is_empty(),
"Concat requires at least one source tensor"
);
let first_source = self.sources[0];
let first_desc = cm.tensor_read(first_source).desc().dims();
assert_eq!(first_desc.len(), 2, "Concat only supports 2D tensors");
let batch_size = first_desc[0] as usize;
let mut total_features: usize = 0;
let mut source_features: Vec<usize> = Vec::with_capacity(self.sources.len());
for &src_id in &self.sources {
let src_desc = cm.tensor_read(src_id).desc().dims();
assert_eq!(
src_desc.len(),
2,
"All source tensors must be 2D for Concat"
);
assert_eq!(
src_desc[0] as usize, batch_size,
"All source tensors must have same batch size"
);
let feat = src_desc[1] as usize;
source_features.push(feat);
total_features += feat;
}
let dst_desc = cm.tensor_read(self.dst).desc().dims();
assert_eq!(dst_desc.len(), 2, "Destination tensor must be 2D");
assert_eq!(
dst_desc[0] as usize, batch_size,
"Destination batch size mismatch"
);
assert_eq!(
dst_desc[1] as usize, total_features,
"Destination feature size mismatch"
);
let mut owned_src_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.sources.len());
let mut src_dims_vec: Vec<Vec<i64>> = Vec::with_capacity(self.sources.len());
for &src_id in &self.sources {
let src_tensor = cm.tensor_read(src_id);
let src_slice = src_tensor.get_cpu_memory_slice_or_panic();
owned_src_bytes.push(src_slice.to_vec());
src_dims_vec.push(src_tensor.desc().dims().to_vec());
}
let src_bytes_vec: Vec<&[u8]> = owned_src_bytes.iter().map(|v| v.as_slice()).collect();
let dst_tensor = cm.tensor_write(self.dst);
let op_datatype = dst_tensor.desc().data_type();
let dst_bytes = dst_tensor.get_cpu_memory_mut_slice_or_panic();
match op_datatype {
DataType::Float => {
f32_cpu(&src_bytes_vec, &src_dims_vec, self.dim, dst_desc, dst_bytes);
}
_ => unimplemented!(
"concat.rs unimplemented cpu instruction for DataType {:?}",
dst_tensor.desc().data_type()
),
}
}
}