vortx 0.2.0

Cross-platform GPU tensor library with Rust.
use crate::shapes::TensorLayoutBuffers;
use crate::tensor::{AsTensorMut, AsTensorRef};
use khal::Shader;
use khal::backend::{GpuBackend, GpuBackendError, GpuPass};

use crate::shaders::linalg::Repeat as GpuRepeat;

/// Module for replicating the content of a source tensor as many times as needed to fill
/// a destination tensor.
///
/// This is a thin wrapper around the [`GpuRepeat`] kernel; see [`Repeat::launch`] for the
/// entry point that runs it.
#[derive(Shader)]
pub struct Repeat {
    /// Kernel for replicating the content of a source tensor as many times as needed to fill
    /// a destination tensor.
    ///
    /// The shape of the destination tensor needs to be an integer multiple of the source tensor's
    /// shape.
    pub repeat: GpuRepeat,
}

// RepeatArgs is now generated by spirv_bindgen from vortx_shaders::linalg::repeat

impl Repeat {
    /// Launches the kernel that repeats the content of `source` into `destination` as many times
    /// as needed to fill `destination`.
    ///
    /// The shape of `destination` must be an integer multiple of the shape of `source`.
    pub fn launch(
        &self,
        backend: &GpuBackend,
        #[cfg_attr(feature = "push_constants", allow(unused_variables))]
        shapes: &mut TensorLayoutBuffers,
        pass: &mut GpuPass,
        mut destination: impl AsTensorMut<f32>,
        source: impl AsTensorRef<f32>,
    ) -> Result<(), GpuBackendError> {
        let mut result = destination.as_tensor_mut();
        let source = source.as_tensor_ref();

        let Some((mut result_shape, mut source_shape)) =
            result.layout().broadcast_assign(source.layout())
        else {
            // TODO: return an error instead of panic.
            panic!(
                "destination: {:?} is incompatible with source: {:?}",
                result.layout(),
                source.layout()
            )
        };

        result_shape = result_shape.canonicalize();
        source_shape = source_shape.canonicalize();

        let num_threads = result.len() as u32;

        #[cfg(not(feature = "push_constants"))]
        {
            shapes.insert(backend, source_shape)?;
            shapes.insert(backend, result_shape)?;
            let shape_source = shapes.get(source_shape).unwrap_or_else(|| unreachable!());
            let shape_result = shapes.get(result_shape).unwrap_or_else(|| unreachable!());
            let mut buf_result = result.buffer_mut();

            self.repeat.call(
                pass,
                num_threads,
                &shape_result.as_slice(),
                &shape_source.as_slice(),
                &mut buf_result,
                &source.buffer(),
            )
        }

        #[cfg(feature = "push_constants")]
        {
            let mut buf_result = result.buffer_mut();

            self.repeat.call(
                pass,
                num_threads,
                &mut buf_result,
                &source.buffer(),
                crate::shaders::linalg::Shapes2 {
                    shape_a: result_shape.into(),
                    shape_b: source_shape.into(),
                },
            )
        }
    }
}