// web-rwkv 0.10.20
//
// An implementation of the RWKV language model in pure WebGPU.
use crate::{
    context::Context,
    num::Float,
    tensor::{ops::TensorOp, TensorCpu, TensorError, TensorGpu, TensorInto},
};

/// Applies softmax to a single tensor and returns the result as a CPU tensor.
///
/// The input is converted into a [`TensorGpu`] via `to(context)`, a
/// [`TensorOp::softmax`] op is built for it, the encoded op is submitted to
/// `context.queue`, and the same tensor is then read back with `back()`.
/// (Presumably the op works in place on that tensor; confirm against
/// `TensorOp::softmax`.)
///
/// An input whose `size()` is zero is handed back unchanged without any
/// upload or submission.
///
/// # Errors
///
/// Propagates any [`TensorError`] returned by [`TensorOp::softmax`].
pub async fn softmax_one<T: Float>(
    context: &Context,
    input: TensorCpu<T>,
) -> Result<TensorCpu<T>, TensorError> {
    // Nothing to compute for an empty tensor: skip the upload entirely.
    if input.size() == 0 {
        return Ok(input);
    }

    let gpu: TensorGpu<_, _> = input.to(context);

    let softmax_op = TensorOp::softmax(&gpu)?;
    let commands = context.encode(&softmax_op);
    context.queue.submit(commands);

    // Read back from the very tensor the op was built on.
    Ok(gpu.back().await)
}

/// Applies softmax to a batch of tensors with a single queue submission.
///
/// Every input is converted into a [`TensorGpu`] via `to(context)`. For each
/// tensor whose `size()` is greater than zero a [`TensorOp::softmax`] op is
/// recorded; all recorded ops are wrapped in one [`TensorOp::List`], encoded,
/// and submitted to `context.queue` in one call. Each tensor is then read
/// back with `back()`, one after another.
///
/// Zero-sized tensors are still uploaded and read back, but no op is recorded
/// for them. The output vector has the same length and ordering as `input`.
///
/// # Errors
///
/// Returns the first [`TensorError`] produced by [`TensorOp::softmax`]; in
/// that case nothing is submitted.
pub async fn softmax<T: Float>(
    context: &Context,
    input: Vec<TensorCpu<T>>,
) -> Result<Vec<TensorCpu<T>>, TensorError> {
    let count = input.len();
    let mut pending = Vec::with_capacity(count);
    let mut uploaded = Vec::with_capacity(count);

    for cpu in input {
        let gpu: TensorGpu<_, _> = cpu.to(context);
        // Only non-empty tensors get a softmax op recorded.
        if gpu.size() > 0 {
            let softmax_op = TensorOp::softmax(&gpu)?;
            pending.push(softmax_op);
        }
        uploaded.push(gpu);
    }

    // One submission covers the whole batch.
    let batch = TensorOp::List(pending);
    let commands = context.encode(&batch);
    context.queue.submit(commands);

    // Read the tensors back sequentially, preserving input order.
    let mut results = Vec::with_capacity(uploaded.len());
    for gpu in uploaded {
        let cpu = gpu.back().await;
        results.push(cpu);
    }
    Ok(results)
}