gllm 0.10.6

Pure Rust library for local embeddings, reranking, and text generation with MoE-optimized inference and aggressive performance tuning
Documentation
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

pub struct KVCache<B: Backend> {
    keys: Vec<Tensor<B, 4>>,
    values: Vec<Tensor<B, 4>>,
    num_layers: usize,
    max_len: usize,
    current_len: usize,
}

impl<B: Backend> KVCache<B> {
    pub fn preallocate(
        num_layers: usize,
        max_len: usize,
        batch_size: usize,
        num_heads: usize,
        head_dim: usize,
        device: &B::Device,
    ) -> Self {
        let mut keys = Vec::with_capacity(num_layers);
        let mut values = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            keys.push(Tensor::zeros([batch_size, num_heads, max_len, head_dim], device));
            values.push(Tensor::zeros([batch_size, num_heads, max_len, head_dim], device));
        }

        Self {
            keys,
            values,
            num_layers,
            max_len,
            current_len: 0,
        }
    }

    pub fn update(
        &mut self,
        layer: usize,
        new_k: Tensor<B, 4>,
        new_v: Tensor<B, 4>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        if layer >= self.num_layers {
            return (new_k, new_v);
        }

        let [batch_size, num_heads, seq_len, head_dim] = new_k.dims();
        if self.max_len == 0 {
            return (new_k, new_v);
        }

        let start = self.current_len;
        let end = start + seq_len;
        assert!(
            end <= self.max_len,
            "KV cache exceeded max_len ({} > {})",
            end,
            self.max_len
        );

        let updated_k = self.keys[layer].clone().slice_assign(
            [0..batch_size, 0..num_heads, start..end, 0..head_dim],
            new_k,
        );
        let updated_v = self.values[layer].clone().slice_assign(
            [0..batch_size, 0..num_heads, start..end, 0..head_dim],
            new_v,
        );

        self.keys[layer] = updated_k;
        self.values[layer] = updated_v;

        if layer + 1 == self.num_layers {
            self.current_len = end;
        }

        let cached_k = self.keys[layer]
            .clone()
            .slice([0..batch_size, 0..num_heads, 0..end, 0..head_dim]);
        let cached_v = self.values[layer]
            .clone()
            .slice([0..batch_size, 0..num_heads, 0..end, 0..head_dim]);

        (cached_k, cached_v)
    }

    pub fn seq_len(&self) -> usize {
        self.current_len
    }
}