vkml 0.0.2

High-level Vulkan-based machine learning library
mod f32_f32_cpu;
mod push_constants;

use crate::ComputeManager;
use crate::VKMLError;
use crate::instruction::softmax::push_constants::SoftmaxPushConstants;
use crate::utils::as_bytes;
use crate::{
    gpu::vk_gpu::Gpu,
    instruction::{Instruction, gpu_operations::GPUOperation, softmax::f32_f32_cpu::f32_f32_cpu},
    tensor_graph::TensorId,
};
use onnx_extractor::DataType;
use std::fmt::{Debug, Formatter, Result as FmtResult};
use vulkanalia::vk;

pub struct SoftmaxInstruction {
    pub src: TensorId,
    pub dst: TensorId,
    pub axis: Option<i64>,
}

impl Debug for SoftmaxInstruction {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        write!(
            f,
            "Softmax(src={}, dst={}, axis={:?})",
            self.src, self.dst, self.axis
        )
    }
}

impl SoftmaxInstruction {
    fn resolve_axis(&self, rank: usize) -> usize {
        let axis = self.axis.unwrap_or(-1);
        if axis < 0 {
            (rank as i64 + axis) as usize
        } else {
            axis as usize
        }
    }
}

impl Instruction for SoftmaxInstruction {
    fn get_input_tensor_ids(&self) -> Vec<TensorId> {
        vec![self.src]
    }

    fn get_output_tensor_ids(&self) -> Vec<TensorId> {
        vec![self.dst]
    }

    fn remap_tensor_ids(&mut self, new_inputs: &[TensorId], new_outputs: &[TensorId]) {
        if !new_inputs.is_empty() {
            self.src = new_inputs[0];
        }

        if !new_outputs.is_empty() {
            self.dst = new_outputs[0];
        }
    }

    fn record_into_command_buffer(
        &self,
        gpu: &Gpu,
        command_buffer: vk::CommandBuffer,
        cm: &ComputeManager,
    ) -> Result<(), VKMLError> {
        let src_tensor = cm.tensor_read(self.src);
        let src_mem = src_tensor.get_gpu_memory_or_panic();
        let dst_tensor = cm.tensor_read(self.dst);
        let dst_mem = dst_tensor.get_gpu_memory_or_panic();

        let dims = src_tensor.desc().dims();
        let dim = self.resolve_axis(dims.len());

        // Currently we only support softmax on the last dimension
        assert_eq!(
            dim,
            dims.len() - 1,
            "Only softmax on the last dimension is currently implemented, requested dimension: {}",
            dim
        );

        let feature_size = dims[dim] as usize;
        let batch_size = src_tensor.desc().num_elements() / feature_size;

        // Create push constants struct (compute before GPU ops)
        let push_constants = SoftmaxPushConstants {
            batch_size: batch_size as u32,
            feature_size: feature_size as u32,
        };

        let pc_bytes = as_bytes(&push_constants);

        // Choose operation based on data type
        let src_dtype = src_tensor.desc().data_type();
        let dst_dtype = dst_tensor.desc().data_type();

        if src_dtype != dst_dtype {
            return Err(VKMLError::Instruction(format!(
                "GPU Softmax unimplemented for DataType src:{:?}, dst:{:?}",
                src_dtype, dst_dtype
            )));
        }

        // Standard path
        let gpu_op = match dst_dtype {
            DataType::Float => GPUOperation::Softmax_FP32,
            DataType::Float16 => GPUOperation::Softmax_FP16,
            _ => {
                return Err(VKMLError::Instruction(format!(
                    "GPU Softmax unsupported for DataType {:?}",
                    dst_dtype
                )));
            }
        };
        let local_size = [256, 1, 1];

        gpu.bind_slang_compute_pipeline(command_buffer, gpu_op, dst_dtype, local_size);
        gpu.bind_storage_buffers(command_buffer, &[src_mem, dst_mem]);
        gpu.bind_push_constants(command_buffer, gpu_op, pc_bytes);

        gpu.dispatch(
            command_buffer,
            local_size,
            [batch_size as u64 * local_size[0] as u64, 1, 1],
        );

        Ok(())
    }

    fn execute_cpu(&self, cm: &ComputeManager) {
        assert!(
            self.src != self.dst,
            "Cannot use Softmax for in-place operation"
        );

        let src_tensor = cm.tensor_read(self.src);
        let dst_tensor = cm.tensor_write(self.dst);

        let dims = src_tensor.desc().dims();
        let dim = self.resolve_axis(dims.len());

        assert_eq!(
            dim,
            dims.len() - 1,
            "CPU Softmax currently only supports the last dimension"
        );

        let src_dtype = src_tensor.desc().data_type();
        let dst_dtype = dst_tensor.desc().data_type();

        let src_bytes = src_tensor.get_cpu_memory_slice_or_panic();
        let dst_ptr = dst_tensor.get_cpu_memory_mut_slice_or_panic();

        match (src_dtype, dst_dtype) {
            (DataType::Float, DataType::Float) => {
                f32_f32_cpu(dims, dim, src_bytes, dst_ptr);
            }
            _ => unimplemented!(
                "softmax.rs unimplemented cpu instruction for DataType src:{:?}, dst:{:?}",
                src_dtype,
                dst_dtype
            ),
        }
    }
}