use crate::VKMLError;
use crate::{ComputeManager, gpu::vk_gpu::Gpu, instruction::Instruction, tensor_graph::TensorId};
use std::fmt::{Debug, Formatter, Result as FmtResult};
use vulkanalia::{vk, vk::DeviceV1_0};
pub struct ReshapeInstruction {
pub src: TensorId,
pub dst: TensorId,
pub shape_values: Vec<i64>,
pub allowzero: Option<i64>,
}
impl Debug for ReshapeInstruction {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(
f,
"Reshape(src={}, dst={}, shape={:?}, allowzero={:?})",
self.src, self.dst, self.shape_values, self.allowzero
)
}
}
impl Instruction for ReshapeInstruction {
fn get_input_tensor_ids(&self) -> Vec<TensorId> {
vec![self.src]
}
fn get_output_tensor_ids(&self) -> Vec<TensorId> {
vec![self.dst]
}
fn remap_tensor_ids(&mut self, new_inputs: &[TensorId], new_outputs: &[TensorId]) {
if !new_inputs.is_empty() {
self.src = new_inputs[0];
}
if !new_outputs.is_empty() {
self.dst = new_outputs[0];
}
}
fn record_into_command_buffer(
&self,
gpu: &Gpu,
command_buffer: vk::CommandBuffer,
cm: &ComputeManager,
) -> Result<(), VKMLError> {
let src_tensor = cm.tensor_read(self.src);
let src_mem = src_tensor.get_gpu_memory_or_panic();
let dst_tensor = cm.tensor_read(self.dst);
let dst_mem = dst_tensor.get_gpu_memory_or_panic();
unsafe {
let copy_region = vk::BufferCopy {
src_offset: 0,
dst_offset: 0,
size: src_mem.size,
};
gpu.get_device().cmd_copy_buffer(
command_buffer,
src_mem.buffer,
dst_mem.buffer,
&[copy_region],
);
}
Ok(())
}
fn execute_cpu(&self, cm: &ComputeManager) {
let mut new_dims = self.shape_values.clone();
let allowzero_flag = self.allowzero.unwrap_or(0) != 0;
if !allowzero_flag {
let src_desc = cm.tensor_read(self.src).desc().clone();
let src_dims = src_desc.dims();
for (i, val) in new_dims.iter_mut().enumerate() {
if *val == 0 {
*val = *src_dims.get(i).unwrap_or(&1);
}
}
} else {
if new_dims.contains(&0) && new_dims.contains(&(-1)) {
panic!("Reshape: 'allowzero' set but shape contains both 0 and -1");
}
}
let src_desc = cm.tensor_read(self.src).desc().clone();
let src_num = src_desc.num_elements();
let neg1_count = new_dims.iter().filter(|&&d| d == -1).count();
if neg1_count > 1 {
panic!("Reshape: more than one -1 in shape is not allowed");
}
if neg1_count == 1 {
let mut prod = 1usize;
for &d in &new_dims {
if d == -1 {
continue;
}
if d < 0 {
panic!("Reshape: negative dimensions other than -1 not allowed");
}
prod = prod.saturating_mul(d as usize);
}
if prod == 0 || !src_num.is_multiple_of(prod) {
panic!("Reshape: cannot infer -1 dimension at runtime");
}
let inferred = (src_num / prod) as i64;
for v in new_dims.iter_mut() {
if *v == -1 {
*v = inferred;
}
}
} else {
let prod: usize = new_dims.iter().map(|&d| d as usize).product();
if prod != src_num {
panic!("Reshape: total elements do not match input tensor at runtime");
}
}
{
let dst_t = cm.tensor_write(self.dst);
dst_t
.desc_mut()
.reshape(new_dims)
.expect("Invalid reshape at runtime");
}
let src_tensor = cm.tensor_read(self.src);
let dst_tensor = cm.tensor_write(self.dst);
dst_tensor.write(&src_tensor.read());
}
}