vortx 0.2.0

A cross-platform GPU tensor library written in Rust.
use crate::shapes::TensorLayoutBuffers;
use crate::tensor::{AsTensorMut, AsTensorRef};
use khal::Shader;
use khal::backend::{GpuBackend, GpuBackendError, GpuPass};

use crate::shaders::linalg::{ReduceAdd, ReduceMax, ReduceMin, ReduceMul, ReduceSqNorm};

#[cfg(test)]
use nalgebra::DVector;

/// The desired operation for the [`Reduce`] kernel.
///
/// A value of this type is passed to `Reduce::launch` to choose which of the bundled
/// kernels is run on the input tensor.
///
/// The enum is `#[non_exhaustive]`: further variants may be added later, so `match`es
/// written outside this crate need a wildcard arm.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[non_exhaustive]
pub enum ReduceVariant {
    /// Minimum: `result = min(input[0], min(input[1], ...))`
    Min,
    /// Maximum: `result = max(input[0], max(input[1], ...))`
    Max,
    /// Sum: `result = input[0] + input[1] ...`
    Sum,
    /// Product: `result = input[0] * input[1] ...`
    Prod,
    /// Squared norm: `result = input[0] * input[0] + input[1] * input[1] ...`
    /// (i.e. the sum of the squares of every element).
    SqNorm,
}

impl ReduceVariant {
    /// Host-side reference value for this reduction, computed with nalgebra.
    ///
    /// Only compiled in test builds, where it provides the expected result that the
    /// value produced by `Reduce::launch` is compared against.
    #[cfg(test)]
    fn eval(self, val: &DVector<f32>) -> f32 {
        match self {
            Self::Sum => val.sum(),
            Self::Prod => val.product(),
            Self::SqNorm => val.norm_squared(),
            Self::Min => val.min(),
            Self::Max => val.max(),
        }
    }
}

/// A GPU kernel for performing the operation described by [`ReduceVariant`].
///
/// The struct bundles one kernel per [`ReduceVariant`]; [`Reduce::launch`] selects the
/// field matching the requested variant and invokes it.
#[derive(Shader)]
pub struct Reduce {
    /// Kernel for computing the sum of every element of a tensor
    /// (used for [`ReduceVariant::Sum`]).
    pub reduce_sum: ReduceAdd,
    /// Kernel for computing the product of every element of a tensor
    /// (used for [`ReduceVariant::Prod`]).
    pub reduce_product: ReduceMul,
    /// Kernel for computing the minimum element of a tensor
    /// (used for [`ReduceVariant::Min`]).
    pub reduce_min: ReduceMin,
    /// Kernel for computing the maximum element of a tensor
    /// (used for [`ReduceVariant::Max`]).
    pub reduce_max: ReduceMax,
    /// Kernel for computing the squared norm of a tensor
    /// (used for [`ReduceVariant::SqNorm`]).
    pub reduce_sqnorm: ReduceSqNorm,
}

// ReduceArgs is now generated by spirv_bindgen from vortx_shaders::linalg::reduce

impl Reduce {
    /// Reduces `input` to a single value using the operation selected by `variant` and
    /// stores the result through `output`.
    ///
    /// The canonicalized layout of `input` drives the reduction:
    ///
    /// - On `GpuBackend::Cpu` (only available with the `cpu` feature) the reduction is
    ///   evaluated right away on the host and written to element `0` of the output
    ///   slice. `shapes` and `pass` are not used on this path.
    /// - On every other backend the kernel matching `variant` is invoked with `pass`.
    ///   Without the `push_constants` feature the layout is first inserted into `shapes`
    ///   and handed to the kernel as a buffer; with `push_constants` it is passed inline
    ///   as a `Shapes1` value and `shapes` is left untouched.
    ///
    /// # Errors
    ///
    /// Returns the error of `shapes.insert` (builds without `push_constants`) or
    /// whatever the selected kernel's `call` returns.
    ///
    /// # Panics
    ///
    /// Panics if the layout that was just inserted into `shapes` cannot be looked up
    /// again, or, on the CPU path, if an index falls outside the host slices.
    pub fn launch(
        &self,
        backend: &GpuBackend,
        #[cfg_attr(feature = "push_constants", allow(unused_variables))]
        shapes: &mut TensorLayoutBuffers,
        pass: &mut GpuPass,
        variant: ReduceVariant,
        input: impl AsTensorRef<f32>,
        mut output: impl AsTensorMut<f32>,
    ) -> Result<(), GpuBackendError> {
        let src = input.as_tensor_ref();
        let mut dst = output.as_tensor_mut();

        let shape = src.layout().canonicalize();

        match backend {
            #[cfg(feature = "cpu")]
            GpuBackend::Cpu => {
                let host_shape: vortx_shaders::linalg::Shape = shape.into();
                let src_buf = src.buffer();
                let src_data = src_buf.unwrap_slice();
                let mut dst_buf = dst.buffer_mut();
                let dst_data = dst_buf.unwrap_slice();
                let len = host_shape.w as usize;

                // Visit the elements in index order, mapping each logical index through
                // the layout so that strided/offset inputs are read correctly.
                let values =
                    (0..len).map(|i| src_data[host_shape.it(0, 0, 0, i as u32) as usize]);

                // Each fold accumulates strictly left-to-right from its seed.
                // NOTE(review): Min/Max are seeded with `f32::MAX` / `f32::MIN` rather than
                // +/- infinity, so an input made only of `+inf` (resp. `-inf`) yields the
                // seed instead. Kept as-is; confirm this matches the GPU kernels.
                let result = match variant {
                    ReduceVariant::Sum => values.fold(0.0f32, |acc, x| acc + x),
                    ReduceVariant::Prod => values.fold(1.0f32, |acc, x| acc * x),
                    ReduceVariant::SqNorm => values.fold(0.0f32, |acc, x| acc + x * x),
                    ReduceVariant::Min => values.fold(f32::MAX, f32::min),
                    ReduceVariant::Max => values.fold(f32::MIN, f32::max),
                };

                dst_data[0] = result;
                Ok(())
            }
            _ => {
                #[cfg(not(feature = "push_constants"))]
                {
                    // The layout is passed to the kernel as a buffer: make sure it
                    // exists in the cache, then fetch it.
                    shapes.insert(backend, shape)?;
                    let layout_buf = shapes.get(shape).unwrap();
                    let mut out_buf = dst.buffer_mut();

                    // The kernels are distinct types, so a macro (rather than a
                    // variable) is used to share the argument list between them.
                    macro_rules! run_kernel {
                        ($kernel:expr) => {
                            $kernel.call(
                                pass,
                                1u32,
                                &layout_buf.as_slice(),
                                &src.buffer(),
                                &mut out_buf,
                            )
                        };
                    }

                    match variant {
                        ReduceVariant::Min => run_kernel!(self.reduce_min),
                        ReduceVariant::Max => run_kernel!(self.reduce_max),
                        ReduceVariant::Sum => run_kernel!(self.reduce_sum),
                        ReduceVariant::Prod => run_kernel!(self.reduce_product),
                        ReduceVariant::SqNorm => run_kernel!(self.reduce_sqnorm),
                    }
                }

                #[cfg(feature = "push_constants")]
                {
                    let mut out_buf = dst.buffer_mut();

                    // Same dispatch as above, but the layout travels inline with the
                    // call instead of through a buffer.
                    macro_rules! run_kernel {
                        ($kernel:expr) => {
                            $kernel.call(
                                pass,
                                1u32,
                                &src.buffer(),
                                &mut out_buf,
                                crate::shaders::linalg::Shapes1 {
                                    shape: shape.into(),
                                },
                            )
                        };
                    }

                    match variant {
                        ReduceVariant::Min => run_kernel!(self.reduce_min),
                        ReduceVariant::Max => run_kernel!(self.reduce_max),
                        ReduceVariant::Sum => run_kernel!(self.reduce_sum),
                        ReduceVariant::Prod => run_kernel!(self.reduce_product),
                        ReduceVariant::SqNorm => run_kernel!(self.reduce_sqnorm),
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::ReduceVariant;
    use crate::shapes::TensorLayoutBuffers;
    use crate::tensor::Tensor;
    use khal::BufferUsages;
    use khal::backend::{Backend, Encoder, GpuBackend, WebGpu};
    use khal::shader::Shader;
    use nalgebra::DVector;

    /// Number of elements in the randomly generated input vector.
    const LEN: u32 = 345;

    /// Runs every reduction variant on `backend` and compares the value read back from
    /// the output tensor with the host-side reference from `ReduceVariant::eval`.
    async fn gpu_reduce_generic(backend: &GpuBackend) {
        let reduce = super::Reduce::from_backend(backend).unwrap();

        for op in [
            ReduceVariant::Min,
            ReduceVariant::Max,
            ReduceVariant::Sum,
            ReduceVariant::Prod,
            ReduceVariant::SqNorm,
        ] {
            println!("Testing: {:?}", op);

            let mut shapes = TensorLayoutBuffers::new(backend);
            let mut encoder = backend.begin_encoding();

            // Fresh random input and a scalar output that can be copied back to the host.
            let v = DVector::new_random(LEN as usize);
            let gpu_v = Tensor::vector(backend, &v, BufferUsages::STORAGE).unwrap();
            let out_usages = BufferUsages::STORAGE | BufferUsages::COPY_SRC;
            let mut gpu_out = Tensor::scalar(backend, 0.0, out_usages).unwrap();

            {
                // Scoped so the pass is ended before the encoder is borrowed again.
                let mut pass = encoder.begin_pass("reduce", None);
                reduce
                    .launch(backend, &mut shapes, &mut pass, op, &gpu_v, &mut gpu_out)
                    .unwrap();
            }

            backend.submit(encoder).unwrap();

            let mut gpu_result = [1.0];
            backend
                .slow_read_buffer(gpu_out.buffer(), &mut gpu_result)
                .await
                .unwrap();

            let expected = op.eval(&v);

            approx::assert_relative_eq!(gpu_result[0], expected, epsilon = 1.0e-3);
        }
    }

    /// Reduction test on the WebGPU backend.
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_reduce_webgpu() {
        let backend = GpuBackend::WebGpu(WebGpu::default().await.unwrap());
        gpu_reduce_generic(&backend).await;
    }

    /// Reduction test on the host (CPU) backend.
    #[cfg(feature = "cpu")]
    #[futures_test::test]
    async fn gpu_reduce_cpu() {
        let backend = GpuBackend::Cpu;
        gpu_reduce_generic(&backend).await;
    }

    /// Reduction test on the CUDA backend (constructed with `Cuda::new(0)`).
    #[cfg(feature = "cuda")]
    #[futures_test::test]
    async fn gpu_reduce_cuda() {
        let backend = GpuBackend::Cuda(khal::backend::cuda::Cuda::new(0).unwrap());
        gpu_reduce_generic(&backend).await;
    }

    /// Reduction test on the Metal backend.
    #[cfg(feature = "metal")]
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_reduce_metal() {
        let backend = GpuBackend::Metal(khal::backend::metal::Metal::new().unwrap());
        gpu_reduce_generic(&backend).await;
    }
}