native_neural_network_std 0.2.1

Ergonomic `std` wrapper around the `no_std` `native_neural_network` crate, providing std-friendly re-exports and utilities.
Documentation
/// Owned, heap-backed key/value cache.
///
/// Keys and values are stored in two flat `f32` buffers. With
/// `stride = num_heads * head_dim`, token `i` occupies elements
/// `i * stride .. (i + 1) * stride` of both `key` and `value`.
/// When built with `KvCacheStd::new`, each buffer holds
/// `max_tokens * stride` zero-initialized elements.
///
/// All fields are public. Writing them directly can make them disagree with the
/// buffer lengths; `token_slices` bounds-checks against the real buffer lengths
/// and reports `ShapeMismatch` in that case instead of panicking.
pub struct KvCacheStd {
    /// Flat key storage, `stride` elements per token slot.
    pub key: Vec<f32>,
    /// Flat value storage, same layout as `key`.
    pub value: Vec<f32>,
    /// Number of token slots the buffers were sized for.
    pub max_tokens: usize,
    /// Elements per head for one token.
    pub head_dim: usize,
    /// Number of heads per token.
    pub num_heads: usize,
    /// Number of tokens appended so far (reset to 0 by `clear`).
    pub used_tokens: usize,
}

/// Errors reported by the KV-cache types in this module.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(res, Err(KvCacheError::CapacityExceeded))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KvCacheError {
    /// Sizes are inconsistent. This covers three cases:
    /// - a stride or offset computation overflowed `usize`;
    /// - the backing buffers are shorter than the computed token range;
    /// - the underlying `native_neural_network` cache reported a shape mismatch
    ///   while appending.
    ShapeMismatch,
    /// A request went beyond the cache's extent. This covers two cases:
    /// - a token index is not below `used_tokens`;
    /// - the underlying `native_neural_network` cache reported it was out of
    ///   capacity while appending.
    CapacityExceeded,
}

impl core::fmt::Display for KvCacheError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            KvCacheError::ShapeMismatch => "shape mismatch",
            KvCacheError::CapacityExceeded => "capacity exceeded",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for KvCacheError {}

impl KvCacheStd {
    pub fn new(max_tokens: usize, head_dim: usize, num_heads: usize) -> Self {
        let stride = head_dim
            .checked_mul(num_heads)
            .expect("kv_cache: head_dim * num_heads overflow");
        let total = max_tokens
            .checked_mul(stride)
            .expect("kv_cache: max_tokens * stride overflow");
        Self {
            key: vec![0f32; total],
            value: vec![0f32; total],
            max_tokens,
            head_dim,
            num_heads,
            used_tokens: 0,
        }
    }

    fn token_stride(&self) -> Option<usize> {
        self.num_heads.checked_mul(self.head_dim)
    }

    pub fn append_token(
        &mut self,
        key_token: &[f32],
        value_token: &[f32],
    ) -> Result<(), KvCacheError> {
        let mut view = native_neural_network::kv_cache::KvCacheView {
            key: self.key.as_mut_slice(),
            value: self.value.as_mut_slice(),
            max_tokens: self.max_tokens,
            head_dim: self.head_dim,
            num_heads: self.num_heads,
            used_tokens: self.used_tokens,
        };
        match view.append_token(key_token, value_token) {
            Ok(()) => {
                self.used_tokens = view.used_tokens;
                Ok(())
            }
            Err(native_neural_network::kv_cache::KvCacheError::ShapeMismatch) => {
                Err(KvCacheError::ShapeMismatch)
            }
            Err(native_neural_network::kv_cache::KvCacheError::CapacityExceeded) => {
                Err(KvCacheError::CapacityExceeded)
            }
        }
    }

    pub fn len_tokens(&self) -> usize {
        self.used_tokens
    }

    pub fn is_empty(&self) -> bool {
        self.used_tokens == 0
    }

    pub fn is_full(&self) -> bool {
        self.used_tokens >= self.max_tokens
    }

    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    pub fn clear(&mut self) {
        self.used_tokens = 0;
    }

    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let off = token_index
            .checked_mul(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
        if end > self.key.len() || end > self.value.len() {
            return Err(KvCacheError::ShapeMismatch);
        }
        Ok((&self.key[off..end], &self.value[off..end]))
    }
}

/// Prints only the capacity counters (`max_tokens`, `used_tokens`). The key and
/// value buffers are left out, so large caches do not flood debug output.
impl core::fmt::Debug for KvCacheStd {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = formatter.debug_struct("KvCacheStd");
        builder.field("max_tokens", &self.max_tokens);
        builder.field("used_tokens", &self.used_tokens);
        builder.finish()
    }
}

/// Borrowed view over caller-owned key/value buffers.
///
/// The layout is the same as [`KvCacheStd`]. With
/// `stride = num_heads * head_dim`, token `i` occupies elements
/// `i * stride .. (i + 1) * stride` of both `key` and `value`.
pub struct KvCacheViewStd<'a> {
    /// Flat key storage borrowed from the caller.
    pub key: &'a mut [f32],
    /// Flat value storage borrowed from the caller, same layout as `key`.
    pub value: &'a mut [f32],
    /// Number of token slots the view is meant to cover.
    pub max_tokens: usize,
    /// Elements per head for one token.
    pub head_dim: usize,
    /// Number of heads per token.
    pub num_heads: usize,
    /// Number of tokens considered populated; bounds `token_slices`.
    pub used_tokens: usize,
}

impl<'a> KvCacheViewStd<'a> {
    /// Returns the `(key, value)` slices of the token at `token_index`.
    ///
    /// # Errors
    ///
    /// Checks run in this order:
    /// 1. `CapacityExceeded` if `token_index >= used_tokens`.
    /// 2. `ShapeMismatch` if `num_heads * head_dim` overflows `usize`.
    /// 3. `ShapeMismatch` if the token's offset range overflows `usize`, or if
    ///    it ends past either buffer.
    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let stride = self
            .num_heads
            .checked_mul(self.head_dim)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let off = token_index
            .checked_mul(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
        // Guard against fields that disagree with the actual buffer lengths, so
        // the slicing below cannot panic.
        if end > self.key.len() || end > self.value.len() {
            return Err(KvCacheError::ShapeMismatch);
        }
        Ok((&self.key[off..end], &self.value[off..end]))
    }
}

/// Prints only the capacity counters (`max_tokens`, `used_tokens`), mirroring
/// the `Debug` output of `KvCacheStd`. The borrowed buffers are omitted because
/// they can be arbitrarily large.
impl<'a> core::fmt::Debug for KvCacheViewStd<'a> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("KvCacheViewStd")
            .field("max_tokens", &self.max_tokens)
            .field("used_tokens", &self.used_tokens)
            .finish()
    }
}

/// Shorter alias for [`KvCacheViewStd`].
pub type KvCacheView<'a> = KvCacheViewStd<'a>;

/// Init hook for the `kv_cache` subsystem. It performs no work and always
/// returns `Ok(())`.
fn kv_cache_init_hook() -> Result<(), crate::InitError> {
    Ok(())
}

/// Registers [`kv_cache_init_hook`] under the name `"kv_cache"` for
/// `crate::InitSubsystem::KvCache`.
///
/// Returns whatever `crate::register_init_hook` returns.
pub fn register_module_init_hooks() -> Result<(), crate::RegisterError> {
    const MODULE_NAME: &str = "kv_cache";
    let subsystem = crate::InitSubsystem::KvCache;
    crate::register_init_hook(MODULE_NAME, subsystem, kv_cache_init_hook)
}