native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::KvCacheError;

/// Mutable view over caller-provided key and value buffers, filled one token
/// at a time.
///
/// Each token occupies `num_heads * head_dim` consecutive `f32` elements in
/// both buffers; token `i` starts at element `i * num_heads * head_dim`.
/// The view only borrows the buffers and does not allocate.
pub struct KvCacheView<'a> {
    /// Flat buffer that receives the key data of every appended token.
    pub key: &'a mut [f32],
    /// Flat buffer that receives the value data of every appended token.
    pub value: &'a mut [f32],
    /// Maximum number of tokens `append_token` will accept.
    pub max_tokens: usize,
    /// Multiplied by `num_heads` to get the per-token element count.
    pub head_dim: usize,
    /// Multiplied by `head_dim` to get the per-token element count.
    pub num_heads: usize,
    /// Number of tokens currently stored; also the slot index the next
    /// `append_token` call writes to.
    pub used_tokens: usize,
}

impl<'a> KvCacheView<'a> {
    /// Returns the number of `f32` elements one token occupies in each buffer
    /// (`num_heads * head_dim`), or `None` if that product overflows `usize`.
    pub fn token_stride(&self) -> Option<usize> {
        self.num_heads.checked_mul(self.head_dim)
    }

    /// Copies one token's key and value data into the next free slot and
    /// increments `used_tokens`.
    ///
    /// # Errors
    ///
    /// * `KvCacheError::ShapeMismatch` if the token stride overflows, if
    ///   either input slice is not exactly one stride long, or if the slot's
    ///   start or end offset overflows `usize`.
    /// * `KvCacheError::CapacityExceeded` if `used_tokens` has reached
    ///   `max_tokens`, or if either backing buffer is too short to hold the
    ///   slot.
    ///
    /// On error the view is left unmodified.
    pub fn append_token(
        &mut self,
        key_token: &[f32],
        value_token: &[f32],
    ) -> Result<(), KvCacheError> {
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        if key_token.len() != stride || value_token.len() != stride {
            return Err(KvCacheError::ShapeMismatch);
        }
        if self.used_tokens >= self.max_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }

        let off = self
            .used_tokens
            .checked_mul(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        // `used_tokens` is a public field, so `off` may be arbitrarily close
        // to `usize::MAX`. A plain `off + stride` would panic in debug builds
        // and wrap in release builds; use checked arithmetic exactly like
        // `token_slices` does.
        let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
        if end > self.key.len() || end > self.value.len() {
            return Err(KvCacheError::CapacityExceeded);
        }
        self.key[off..end].copy_from_slice(key_token);
        self.value[off..end].copy_from_slice(value_token);
        self.used_tokens += 1;
        Ok(())
    }

    /// Returns the number of tokens currently stored.
    pub fn len_tokens(&self) -> usize {
        self.used_tokens
    }

    /// Returns `true` when no tokens are stored.
    pub fn is_empty(&self) -> bool {
        self.used_tokens == 0
    }

    /// Returns `true` when `used_tokens` has reached (or exceeds) `max_tokens`.
    pub fn is_full(&self) -> bool {
        self.used_tokens >= self.max_tokens
    }

    /// Returns how many more tokens fit before `max_tokens` is reached,
    /// saturating at zero if `used_tokens` already exceeds it.
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    /// Resets the token count to zero. The buffer contents are not touched;
    /// subsequent appends overwrite them.
    pub fn clear(&mut self) {
        self.used_tokens = 0;
    }

    /// Returns the `(key, value)` slices of the stored token at `token_index`.
    ///
    /// # Errors
    ///
    /// * `KvCacheError::CapacityExceeded` if `token_index` is not below
    ///   `used_tokens`.
    /// * `KvCacheError::ShapeMismatch` if the stride or the slot offsets
    ///   overflow `usize`, or if the slot extends past either buffer.
    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        let off = token_index
            .checked_mul(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
        if end > self.key.len() || end > self.value.len() {
            return Err(KvCacheError::ShapeMismatch);
        }
        Ok((&self.key[off..end], &self.value[off..end]))
    }
}