mod f32_f32_cpu;
mod push_constants;
use crate::ComputeManager;
use crate::VKMLError;
use crate::instruction::expand::f32_f32_cpu::f32_f32_cpu;
use crate::instruction::expand::push_constants::ExpandPushConstants;
use crate::utils::as_bytes;
use crate::{
gpu::vk_gpu::Gpu,
instruction::{Instruction, gpu_operations::GPUOperation},
tensor::TensorDesc,
tensor_graph::TensorId,
};
use onnx_extractor::DataType;
use std::fmt::{Debug, Formatter, Result as FmtResult};
use vulkanalia::vk;
/// Instruction that expands (broadcasts) the `src` tensor into the `dst` tensor.
pub struct ExpandInstruction {
    /// Id of the tensor that is read as the input of the expansion.
    pub src: TensorId,
    /// Id of the tensor that receives the expanded data.
    pub dst: TensorId,
    /// Shape values carried with the instruction.
    ///
    /// NOTE(review): presumably the target shape requested by the operator;
    /// within this file it is only printed by the `Debug` impl, while the
    /// execution paths take the output shape from the `dst` tensor descriptor.
    pub shape_values: Vec<i64>,
}
/// Compact one-line debug representation listing the source id, the
/// destination id and the carried shape values.
impl Debug for ExpandInstruction {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Destructure so a newly added field forces this impl to be revisited.
        let Self {
            src,
            dst,
            shape_values,
        } = self;
        write!(
            f,
            "Expand(src={}, dst={}, shape={:?})",
            src, dst, shape_values
        )
    }
}
impl Instruction for ExpandInstruction {
fn get_input_tensor_ids(&self) -> Vec<TensorId> {
vec![self.src]
}
fn get_output_tensor_ids(&self) -> Vec<TensorId> {
vec![self.dst]
}
fn remap_tensor_ids(&mut self, new_inputs: &[TensorId], new_outputs: &[TensorId]) {
if !new_inputs.is_empty() {
self.src = new_inputs[0];
}
if !new_outputs.is_empty() {
self.dst = new_outputs[0];
}
}
fn record_into_command_buffer(
&self,
gpu: &Gpu,
command_buffer: vk::CommandBuffer,
cm: &ComputeManager,
) -> Result<(), VKMLError> {
let src_tensor = cm.tensor_read(self.src);
let src_mem = src_tensor.get_gpu_memory_or_panic();
let dst_tensor = cm.tensor_read(self.dst);
let dst_mem = dst_tensor.get_gpu_memory_or_panic();
let src_desc = src_tensor.desc();
let dst_desc = dst_tensor.desc();
let src_dims_usize = src_desc.dims();
let dst_dims_usize = dst_desc.dims();
let rank = dst_dims_usize.len() as u32;
assert!(
rank <= 8,
"Expand: tensor rank {} exceeds maximum supported rank of 8",
rank
);
let mut dims_arr = [0u32; 8];
for (i, &d) in dst_dims_usize.iter().enumerate().take(8) {
dims_arr[i] = d as u32;
}
let strides_src_usize = TensorDesc::broadcast_strides(src_dims_usize, dst_dims_usize);
let mut strides_src_arr = [0u32; 8];
for (i, &s) in strides_src_usize.iter().enumerate().take(8) {
strides_src_arr[i] = s as u32;
}
let total_elements: u64 = dst_dims_usize.iter().map(|d| *d as u64).product();
let push_const_values = ExpandPushConstants {
rank,
pad: 0,
total: total_elements as u32,
dims: dims_arr,
strides_src: strides_src_arr,
};
let push_constant_bytes = as_bytes(&push_const_values);
let src_dtype = src_desc.data_type();
let dst_dtype = dst_desc.data_type();
if src_dtype != dst_dtype {
return Err(VKMLError::Instruction(format!(
"GPU Expand unimplemented for DataType src:{:?}, dst:{:?}",
src_dtype, dst_dtype
)));
}
let gpu_op = GPUOperation::Expand;
let local_size = gpu.optimal_workgroup_size_1d(total_elements);
gpu.bind_slang_compute_pipeline(command_buffer, gpu_op, dst_dtype, local_size);
gpu.bind_storage_buffers(command_buffer, &[src_mem, dst_mem]);
gpu.bind_push_constants(command_buffer, gpu_op, push_constant_bytes);
let num_elements: u64 = dst_dims_usize.iter().map(|d| *d as u64).product();
gpu.dispatch(command_buffer, local_size, [num_elements, 1, 1]);
Ok(())
}
fn execute_cpu(&self, cm: &ComputeManager) {
let src_tensor = cm.tensor_read(self.src);
let dst_tensor = cm.tensor_write(self.dst);
let src_dims = src_tensor.desc().dims();
let dst_dims = dst_tensor.desc().dims().to_vec();
let src_rank = src_dims.len();
let dst_rank = dst_dims.len();
let mut padded_src_dims = vec![1; dst_rank];
let offset = dst_rank.saturating_sub(src_rank);
for (i, &dim) in src_dims.iter().enumerate() {
padded_src_dims[offset + i] = dim;
}
for i in 0..dst_rank {
let src_dim = padded_src_dims[i];
let dst_dim = dst_dims[i];
if src_dim != dst_dim && src_dim != 1 {
panic!(
"Expand: incompatible shapes src={:?} (padded={:?}), dst={:?}",
src_dims, padded_src_dims, dst_dims
);
}
}
let strides_src = TensorDesc::broadcast_strides(src_dims, &dst_dims);
let src_dtype = src_tensor.desc().data_type();
let dst_dtype = dst_tensor.desc().data_type();
let src_bytes = src_tensor.get_cpu_memory_slice_or_panic();
let dst_ptr = dst_tensor.get_cpu_memory_mut_slice_or_panic();
match (src_dtype, dst_dtype) {
(DataType::Float, DataType::Float) => {
f32_f32_cpu(strides_src, dst_dims, src_bytes, dst_ptr)
}
_ => unimplemented!(
"expand.rs unimplemented cpu instruction for DataType src:{:?}, dst:{:?}",
src_dtype,
dst_dtype
),
}
}
}