//! vkml 0.0.3
//!
//! High-level Vulkan-based machine learning library
mod f32_f32_f32_cpu;
mod push_constants;

use crate::ComputeManager;
use crate::VKMLError;
use crate::instruction::mul::push_constants::MulPushConstants;
use crate::utils::as_bytes;
use crate::{
    gpu::vk_gpu::Gpu,
    instruction::{
        Instruction, gpu_operations::GPUOperation, mul::f32_f32_f32_cpu::f32_f32_f32_cpu,
    },
    tensor::TensorDesc,
    tensor_graph::TensorId,
};
use onnx_extractor::DataType;
use std::fmt::{Debug, Formatter, Result as FmtResult};
use vulkanalia::vk;

/// Multiplication instruction that combines two source tensors into a
/// destination tensor, broadcasting the source shapes against each other.
///
/// The destination shape must equal the broadcast of the two source shapes,
/// and the CPU path requires `dst` to be distinct from both sources.
pub struct MulInstruction {
    /// Identifier of the first (left-hand) input tensor.
    pub src1: TensorId,
    /// Identifier of the second (right-hand) input tensor.
    pub src2: TensorId,
    /// Identifier of the output tensor that receives the product.
    pub dst: TensorId,
}

/// Compact single-line debug rendering: `Mul(src1=…, src2=…, dst=…)`.
impl Debug for MulInstruction {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Destructure so a newly added field forces this impl to be revisited.
        let Self { src1, src2, dst } = self;
        f.write_fmt(format_args!(
            "Mul(src1={}, src2={}, dst={})",
            src1, src2, dst
        ))
    }
}

impl Instruction for MulInstruction {
    fn get_input_tensor_ids(&self) -> Vec<TensorId> {
        vec![self.src1, self.src2]
    }

    fn get_output_tensor_ids(&self) -> Vec<TensorId> {
        vec![self.dst]
    }

    fn remap_tensor_ids(&mut self, new_inputs: &[TensorId], new_outputs: &[TensorId]) {
        if new_inputs.len() >= 2 {
            self.src1 = new_inputs[0];
            self.src2 = new_inputs[1];
        }

        if !new_outputs.is_empty() {
            self.dst = new_outputs[0];
        }
    }

    fn record_into_command_buffer(
        &self,
        gpu: &Gpu,
        command_buffer: vk::CommandBuffer,
        cm: &ComputeManager,
    ) -> Result<(), VKMLError> {
        let src1_tensor = cm.tensor_read(self.src1);
        let src1_mem = src1_tensor.get_gpu_memory_or_panic();
        let src2_tensor = cm.tensor_read(self.src2);
        let src2_mem = src2_tensor.get_gpu_memory_or_panic();
        let dst_tensor = cm.tensor_read(self.dst);
        let dst_mem = dst_tensor.get_gpu_memory_or_panic();

        // Get tensor descriptions for calculating broadcast shapes and strides
        let src1_desc = src1_tensor.desc();
        let src2_desc = src2_tensor.desc();
        let dst_desc = dst_tensor.desc();

        let src1_dims = src1_desc.dims();
        let src2_dims = src2_desc.dims();
        let dst_dims = dst_desc.dims();

        // Prepare push constant data using shared PushConstants

        let rank = dst_dims.len() as u32;
        assert!(
            rank <= 8,
            "Mul: tensor rank {} exceeds maximum supported rank of 8",
            rank
        );

        let mut dims_arr = [0u32; 8];
        for (i, &d) in dst_dims.iter().enumerate().take(8) {
            dims_arr[i] = d as u32;
        }

        // Calculate broadcast shape and strides (similar to execute_cpu)
        let broadcast_dims =
            TensorDesc::broadcast_shape(src1_dims, src2_dims).unwrap_or_else(|| {
                panic!(
                    "GPU Mul: Can't broadcast {:?} vs {:?}",
                    src1_dims, src2_dims
                )
            });

        assert_eq!(
            broadcast_dims.as_slice(),
            dst_dims,
            "GPU Mul: Broadcast shape {:?} != dst shape {:?}",
            broadcast_dims,
            dst_dims
        );

        let strides_a_usize = TensorDesc::broadcast_strides(src1_dims, dst_dims);
        let strides_b_usize = TensorDesc::broadcast_strides(src2_dims, dst_dims);

        let mut strides_a_arr = [0u32; 8];
        for (i, &s) in strides_a_usize.iter().enumerate().take(8) {
            strides_a_arr[i] = s as u32;
        }

        let mut strides_b_arr = [0u32; 8];
        for (i, &s) in strides_b_usize.iter().enumerate().take(8) {
            strides_b_arr[i] = s as u32;
        }

        let total_elements: u64 = dst_dims.iter().map(|d| *d as u64).product();

        let push_const_values = MulPushConstants {
            rank,
            pad: 0, // Padding value, 0 is fine
            total: total_elements as u32,
            dims: dims_arr,
            strides_a: strides_a_arr,
            strides_b: strides_b_arr,
        };

        let push_constant_bytes = as_bytes(&push_const_values);

        let src1_dtype = src1_desc.data_type();
        let src2_dtype = src2_desc.data_type();
        let dst_dtype = dst_desc.data_type();

        if src1_dtype != src2_dtype || src1_dtype != dst_dtype {
            return Err(VKMLError::Instruction(format!(
                "GPU Mul unimplemented for mixed DataType src1:{:?}, src2:{:?}, dst:{:?}",
                src1_dtype, src2_dtype, dst_dtype
            )));
        }

        let op_name = GPUOperation::Multiply;

        // Choose 1D local workgroup size and bind pipeline/descriptors
        let local_size = gpu.optimal_workgroup_size_1d(total_elements);

        gpu.bind_slang_compute_pipeline(command_buffer, op_name, dst_dtype, local_size);
        gpu.bind_storage_buffers(command_buffer, &[src1_mem, src2_mem, dst_mem]);

        // Push constants to the shader
        gpu.bind_push_constants(command_buffer, op_name, push_constant_bytes);

        let num_elements: u64 = dst_dims.iter().map(|d| *d as u64).product();

        gpu.dispatch(command_buffer, local_size, [num_elements, 1, 1]);

        Ok(())
    }

    fn execute_cpu(&self, cm: &ComputeManager) {
        assert!(
            self.src1 != self.dst && self.src2 != self.dst,
            "Cannot use Mul for in-place operation. Use MulInplace instead."
        );

        let src1_tensor = cm.tensor_read(self.src1);
        let src2_tensor = cm.tensor_read(self.src2);
        let dst_tensor = cm.tensor_write(self.dst);

        let a = src1_tensor.desc().dims();
        let b = src2_tensor.desc().dims();
        let c = dst_tensor.desc().dims().to_vec();

        let bc = TensorDesc::broadcast_shape(a, b)
            .unwrap_or_else(|| panic!("Can't broadcast {:?} vs {:?}", a, b));
        assert_eq!(bc, c, "Broadcast {:?} != dst {:?}", bc, c);

        let sa = TensorDesc::broadcast_strides(a, &c);
        let sb = TensorDesc::broadcast_strides(b, &c);

        let src1_dtype = src1_tensor.desc().data_type();
        let src2_dtype = src2_tensor.desc().data_type();
        let dst_dtype = dst_tensor.desc().data_type();

        let src1_bytes = src1_tensor.get_cpu_memory_slice_or_panic();
        let src2_bytes = src2_tensor.get_cpu_memory_slice_or_panic();
        let dst_ptr = dst_tensor.get_cpu_memory_mut_slice_or_panic();

        match (src1_dtype, src2_dtype, dst_dtype) {
            (DataType::Float, DataType::Float, DataType::Float) => {
                f32_f32_f32_cpu(sa, sb, c, src1_bytes, src2_bytes, dst_ptr)
            }
            _ => unimplemented!(
                "mul.rs unimplemented cpu instruction for DataType src1:{:?}, src2:{:?}, dst:{:?}",
                src1_dtype,
                src2_dtype,
                dst_dtype
            ),
        }
    }
}