web-rwkv 0.10.20

An implementation of the RWKV language model in pure WebGPU.
use web_rwkv_derive::{Deref, DerefMut};

use crate::{
    runtime::{JobInfo, JobInput},
    tensor::{TensorCpu, TensorShape},
};

#[derive(Debug, Default, Clone, Copy)]
pub struct Vision;

impl super::Infer for Vision {
    type Info = VisionInfo;
    type Input = VisionInput;
    type Output = VisionOutput;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deref, DerefMut)]
pub struct VisionInfo([usize; 4]);

impl JobInfo for VisionInfo {
    fn check(&self, info: &Self) -> bool {
        self.eq(info)
    }
}

/// An input picture consists of a list of patches. The shape of the tensor should be `[X, Y, C, N]`.
#[derive(Debug, Clone, Deref, DerefMut)]
pub struct VisionInput(pub TensorCpu<f32>);

impl VisionInput {
    #[inline]
    pub fn x(&self) -> usize {
        self.0.shape()[0]
    }

    #[inline]
    pub fn y(&self) -> usize {
        self.0.shape()[1]
    }

    #[inline]
    pub fn num_channel(&self) -> usize {
        self.0.shape()[2]
    }

    #[inline]
    pub fn num_patch(&self) -> usize {
        self.0.shape()[3]
    }

    #[inline]
    pub fn num_emb(&self) -> usize {
        self.x() * self.y() * self.num_channel()
    }

    #[inline]
    pub fn info(&self) -> VisionInfo {
        VisionInfo(self.shape().into())
    }
}

impl JobInput for VisionInput {
    type Chunk = Self;

    fn step(&mut self) {}

    fn chunk(&self) -> Self::Chunk {
        self.clone()
    }
}

impl Iterator for &VisionInput {
    type Item = VisionInfo;

    fn next(&mut self) -> Option<Self::Item> {
        Some(self.info())
    }
}

/// The embedding vector of the picture.
#[derive(Debug, Clone, Deref, DerefMut)]
pub struct VisionOutput(pub TensorCpu<f32>);