hanzo-nn 0.10.2

Minimalist ML framework.
Documentation
pub mod cpu_flash;
pub mod varlen;

use hanzo_ml::Tensor;

pub use cpu_flash::flash_attn;
pub use cpu_flash::varlen::flash_attn_varlen_cpu;
pub use varlen::flash_attn_varlen_unfused;

/// Selects which attention mask, if any, is in effect.
///
/// The default value is [`AttnMask::None`].
#[derive(Debug, Clone, Default)]
pub enum AttnMask {
    /// No mask is specified. This is the `Default` variant.
    #[default]
    None,
    /// Causal masking, parameterised by a key/value offset.
    Causal {
        /// Offset associated with the causal mask.
        ///
        /// NOTE(review): presumably the number of key/value positions that
        /// precede the first query position (e.g. cached tokens) — confirm
        /// against the attention kernels that consume this value.
        kv_offset: usize,
    },
    /// A mask supplied explicitly as a [`Tensor`].
    ///
    /// NOTE(review): expected shape, dtype and whether the mask is additive
    /// or boolean are not visible here — confirm against the consumers.
    Mask(Tensor),
}

impl AttnMask {
    #[inline]
    pub fn causal() -> Self {
        AttnMask::Causal { kv_offset: 0 }
    }

    #[inline]
    pub fn causal_with_offset(kv_offset: usize) -> Self {
        AttnMask::Causal { kv_offset }
    }

    #[inline]
    pub fn is_causal(&self) -> bool {
        matches!(self, AttnMask::Causal { .. })
    }

    #[inline]
    pub fn kv_offset(&self) -> usize {
        match self {
            AttnMask::Causal { kv_offset } => *kv_offset,
            _ => 0,
        }
    }
}