web-rwkv 0.10.20

An implementation of the RWKV language model in pure WebGPU.
use web_rwkv::{
    context::{BindGroupBuilder, Macros, PipelineKey},
    num::Float,
    tensor::{ops::TensorOp, TensorError, TensorGpuView, TensorShape},
};

/// Extension trait that adds the custom `mul_w` operation constructor to a
/// tensor-op type (this file implements it for [`TensorOp`]).
pub trait TensorOpExt: Sized {
    /// Creates the `mul_w` operation from an `input` view and an `output` view.
    ///
    /// `F0` and `F1` are the [`Float`] type parameters of the `input` and
    /// `output` views respectively; they are independent of each other.
    ///
    /// # Errors
    ///
    /// Returns a [`TensorError`] when the operation cannot be constructed
    /// from the given views.
    fn mul_w<'a, 'b, F0, F1>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError>
    where
        F0: Float,
        F1: Float;
}

impl TensorOpExt for TensorOp {
    /// Builds a [`TensorOp::Atom`] that runs the `mul_w` entry point of
    /// `mul_w.wgsl` with `input` and `output` bound to it.
    ///
    /// With `[index, token, batch, _]` taken from `output`'s shape, `input`
    /// must have shape `[index, 1, batch, 1]` or `[index, token, batch, 1]`,
    /// and `output` must have shape `[index, token, batch, 1]`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `check_shape` when `input` matches neither
    /// accepted shape, or when `output`'s last dimension is not `1`.
    fn mul_w<'a, 'b, F0, F1>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError>
    where
        F0: Float,
        F1: Float,
    {
        // Workgroup size; also handed to the shader as the `BLOCK_SIZE` macro.
        const BLOCK_SIZE: u32 = 128;

        let input_view: TensorGpuView<'a, F0> = input.into();
        let output_view: TensorGpuView<'b, F1> = output.into();

        let context = output_view.context();

        // Validate both views against the output's leading three dimensions.
        let [d0, d1, d2, _] = output_view.shape().into();
        let single_token = input_view.check_shape([d0, 1, d2, 1]);
        let per_token = input_view.check_shape([d0, d1, d2, 1]);
        single_token.or(per_token)?;
        output_view.check_shape([d0, d1, d2, 1])?;
        let shape = output_view.shape();

        // The pipeline key carries the shader name, entry point and macros.
        let macros = Macros::new()
            .u32("BLOCK_SIZE", BLOCK_SIZE)
            .tensor(&input_view, Some("IN"))
            .tensor(&output_view, Some("OUT"));
        let key = PipelineKey::new("mul_w", "mul_w", macros);

        // Bindings 0/1 hold the views' metadata, bindings 2/3 the tensor data.
        // NOTE(review): the boolean passed to `layout` presumably marks the
        // binding as read-only — confirm against `TensorGpuView::layout`.
        let entries = [
            input_view.meta_layout(0),
            output_view.meta_layout(1),
            input_view.layout(2, true),
            output_view.layout(3, false),
        ];
        let pipeline = context.checkout_pipeline(&key, include_str!("mul_w.wgsl"), &entries);

        let bind_group = BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input_view)
            .bind_meta(1, &output_view)
            .bind(2, &input_view)
            .bind(3, &output_view)
            .build();
        let bindings = vec![bind_group];

        // X covers `shape[0] / 4` items in groups of `BLOCK_SIZE`; the integer
        // division drops any remainder of `shape[0]` modulo 4.
        // NOTE(review): presumably the shader handles four elements per
        // invocation — confirm against `mul_w.wgsl`.
        let groups_x = (shape[0] as u32 / 4).div_ceil(BLOCK_SIZE);
        let dispatch = [groups_x, shape[1] as u32, shape[2] as u32];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch,
        })
    }
}