web-rwkv 0.10.20

An implementation of the RWKV language model in pure WebGPU.
use half::f16;
use serde::{Deserialize, Serialize};

use super::{JobInfo, JobInput};
use crate::tensor::{TensorCpu, TensorInit};

pub mod rnn;
pub mod vision;

pub use rnn::{
    Rnn, RnnChunk, RnnChunkBatch, RnnInfo, RnnInfoBatch, RnnInput, RnnInputBatch, RnnIter,
    RnnOption, RnnOutput, RnnOutputBatch, RnnRedirect,
};

/// Describes one kind of inference job by naming the types involved in it.
///
/// Implementors, and the output they produce, must be `'static + Send + Sync`. That lets them be
/// moved to and shared between threads.
pub trait Infer: 'static + Sync + Send {
    /// The value the job yields; must be thread-safe and own its data.
    type Output: 'static + Sync + Send;
    /// The data fed into the job; bounded by [`JobInput`].
    type Input: JobInput;
    /// The job's descriptor; bounded by [`JobInfo`].
    type Info: JobInfo;
}

/// A single token: either a numeric id or an embedding vector stored as `f16` values.
///
/// NOTE: keep the variant order stable. Serde formats that encode enum variants by index
/// would otherwise change their wire representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Token {
    /// A numeric token id (presumably an index into the model vocabulary — confirm against
    /// the consumer of this type).
    Token(u32),
    /// An embedding supplied directly, held as a CPU tensor of `f16` elements.
    Embed(TensorCpu<f16>),
}

impl Default for Token {
    fn default() -> Self {
        Self::Token(0)
    }
}

impl From<u32> for Token {
    fn from(value: u32) -> Self {
        Self::Token(value)
    }
}

impl From<u16> for Token {
    /// Widens a `u16` id to `u32` and wraps it as [`Token::Token`].
    ///
    /// `u32::from` is used instead of an `as` cast. The widening is lossless, and the
    /// compiler will reject this line if the source type ever changes to one that could
    /// truncate (clippy `cast_lossless`).
    fn from(value: u16) -> Self {
        Self::Token(u32::from(value))
    }
}

impl From<Vec<f16>> for Token {
    fn from(value: Vec<f16>) -> Self {
        Self::Embed(TensorCpu::from_data_1d(value))
    }
}

impl From<Vec<f32>> for Token {
    fn from(value: Vec<f32>) -> Self {
        let value: Vec<_> = value.into_iter().map(f16::from_f32).collect();
        Self::Embed(TensorCpu::from_data_1d(value))
    }
}