mod architecture;
pub mod cache;
mod config;
mod kv_quantized;
pub mod kv_turboquant;
pub mod deltanet;
pub mod mamba;
pub mod embeddings;
mod error;
pub mod hf_config;
pub mod layers;
mod llama;
pub mod bert;
mod loader;
pub mod lora;
pub mod moe;
pub mod paged;
pub mod speculative;
pub mod turboquant;
pub use architecture::Architecture;
pub use kv_quantized::{KVCacheFormat, QuantizedKVCache};
pub use kv_turboquant::TurboQuantKVCache;
pub use turboquant::TurboQuantConfig;
pub use cache::{
CachedPrefix, PrefixId, PrefixSharing, PromptCache, PromptCacheConfig, PromptCacheStats,
};
pub use config::{ActivationType, AttentionLayerConfig, AttentionLayerType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
pub use embeddings::{
EmbeddingConfig, EmbeddingError, EmbeddingExtractor, PoolingStrategy, TruncationStrategy,
cosine_similarity, dot_product, euclidean_distance, find_nearest,
};
pub use error::{ModelError, ModelResult};
pub use hf_config::{HfConfig, RopeScalingConfig};
pub use deltanet::{
DeltaNetConfig, DeltaNetLayer, DeltaNetState, RecurrentConfig, RecurrentLayerState,
RecurrentState,
};
pub use mamba::{MambaConfig, MambaState, MambaLayer};
pub use bert::{BertLayer, BertModel};
pub use layers::{AttentionLayer, FfnLayer, TransformerLayer};
pub use llama::LlamaModel;
pub use loader::{ModelLoader, ModelSource, build_llama_model, load_llama_model};
pub use lora::{LoraAdapter, LoraAdapters, LoraConfig};
pub use moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter, MoeStats};
pub use paged::{BlockId, BlockTable, PageAllocator, PagedKVPool, PagedSequence, DEFAULT_BLOCK_SIZE};
pub use speculative::{SpeculativeConfig, SpeculativeDecoder, SpeculativeMode, SpeculativeStats};
use std::sync::Arc;
use crate::backend::Backend;
use crate::tensor::Tensor;
/// Per-layer key/value cache stored as `F32` tensors.
///
/// The constructors allocate one K tensor and one V tensor per layer. A layer
/// that owns its buffers gets tensors of shape
/// `[num_kv_heads, max_seq_len, head_dim]`; a layer whose `kv_source_layer`
/// entry names a different layer gets a zero-element placeholder of shape `[0]`.
#[derive(Debug)]
pub struct KVCache {
/// Key tensors, indexed by layer.
pub k_cache: Vec<Tensor>,
/// Value tensors, indexed by layer.
pub v_cache: Vec<Tensor>,
/// Number of positions currently considered valid; starts at 0 and is set
/// back to 0 by `reset`.
pub seq_len: usize,
/// Capacity along the sequence dimension of every owned buffer.
pub max_seq_len: usize,
/// KV head count. For a heterogeneous cache this is the value of the FIRST
/// layer config only; individual layers may differ.
pub num_kv_heads: usize,
/// Per-head dimension. For a heterogeneous cache this is the value of the
/// FIRST layer config only; individual layers may differ.
pub head_dim: usize,
/// Number of layers (length of `k_cache` / `v_cache`).
pub num_layers: usize,
/// Layer mapping: `kv_source_layer[i] == i` means layer `i` owns its buffers.
/// Any other value is the index of another layer — presumably the layer whose
/// K/V it reuses; confirm against the attention code that reads this field.
pub kv_source_layer: Vec<usize>,
}
impl KVCache {
    /// Creates a cache in which every layer owns its own key and value buffers.
    ///
    /// Each of the `num_layers` K tensors and V tensors is a zero-filled `F32`
    /// tensor of shape `[num_kv_heads, max_seq_len, head_dim]`. The layer
    /// mapping is the identity (`kv_source_layer[i] == i`) and `seq_len`
    /// starts at 0.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        max_seq_len: usize,
        head_dim: usize,
    ) -> Self {
        use crate::tensor::DType;

        // K and V are allocated identically, so share one allocator.
        let alloc = || -> Vec<Tensor> {
            (0..num_layers)
                .map(|_| Tensor::zeros(vec![num_kv_heads, max_seq_len, head_dim], DType::F32))
                .collect()
        };

        Self {
            k_cache: alloc(),
            v_cache: alloc(),
            seq_len: 0,
            max_seq_len,
            num_kv_heads,
            head_dim,
            num_layers,
            kv_source_layer: (0..num_layers).collect(),
        }
    }

    /// Creates a cache whose layers may differ in head count and head size and
    /// may refer to another layer through `kv_source_layer`.
    ///
    /// Layer `i` gets real buffers of shape
    /// `[cfg.num_kv_heads, max_seq_len, cfg.head_dim]` only when
    /// `kv_source_layer[i] == i`; otherwise it gets a zero-element placeholder
    /// of shape `[0]`. The struct-level `num_kv_heads` / `head_dim` fields are
    /// copied from the first layer config.
    ///
    /// # Panics
    ///
    /// Panics if `layer_configs` is empty or if `kv_source_layer` has fewer
    /// entries than `layer_configs`.
    pub fn new_heterogeneous(
        layer_configs: &[AttentionLayerConfig],
        max_seq_len: usize,
        kv_source_layer: Vec<usize>,
    ) -> Self {
        use crate::tensor::DType;

        let num_layers = layer_configs.len();

        // K and V follow the same ownership rule, so share one allocator.
        let alloc = || -> Vec<Tensor> {
            layer_configs
                .iter()
                .enumerate()
                .map(|(i, cfg)| {
                    if kv_source_layer[i] == i {
                        Tensor::zeros(
                            vec![cfg.num_kv_heads, max_seq_len, cfg.head_dim],
                            DType::F32,
                        )
                    } else {
                        // This layer points at another layer: keep only a placeholder.
                        Tensor::zeros(vec![0], DType::F32)
                    }
                })
                .collect()
        };
        let k_cache = alloc();
        let v_cache = alloc();

        let first = &layer_configs[0];
        Self {
            k_cache,
            v_cache,
            seq_len: 0,
            max_seq_len,
            num_kv_heads: first.num_kv_heads,
            head_dim: first.head_dim,
            num_layers,
            kv_source_layer,
        }
    }

    /// Marks the cache as empty. Buffer contents are left untouched.
    pub fn reset(&mut self) {
        self.seq_len = 0;
    }

    /// Number of positions that can still be stored before the cache is full.
    pub fn remaining_capacity(&self) -> usize {
        self.max_seq_len.saturating_sub(self.seq_len)
    }

    /// Returns `true` once `seq_len` has reached (or exceeded) `max_seq_len`.
    pub fn is_full(&self) -> bool {
        self.seq_len >= self.max_seq_len
    }

    /// Shrinks the valid length to `new_len`.
    ///
    /// Does nothing when `new_len >= seq_len`; the cache never grows here.
    /// Buffer contents are left untouched.
    pub fn truncate(&mut self, new_len: usize) {
        if new_len < self.seq_len {
            self.seq_len = new_len;
        }
    }

    /// Discards the oldest `amount` positions and moves the remaining ones to
    /// the front of every owned buffer.
    ///
    /// * `amount == 0` is a no-op: nothing is discarded and `seq_len` is kept.
    /// * `amount >= seq_len` empties the cache without moving any data.
    /// * Otherwise `seq_len` becomes `seq_len - amount`.
    ///
    /// Layers whose `kv_source_layer` entry names another layer are skipped,
    /// as are layers whose K tensor has fewer than three dimensions. A tensor
    /// for which `as_f32_mut()` fails is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `kv_source_layer` is shorter than `num_layers`, or if
    /// `seq_len` exceeds the sequence dimension of an owned buffer.
    pub fn shift_left(&mut self, amount: usize) {
        /// Moves `keep` floats starting `skip` floats into each head's row to
        /// the start of that row. Rows are `row_stride` floats apart.
        fn shift_rows(
            data: &mut [f32],
            num_heads: usize,
            row_stride: usize,
            skip: usize,
            keep: usize,
        ) {
            for head in 0..num_heads {
                let base = head * row_stride;
                let src_start = base + skip;
                data.copy_within(src_start..src_start + keep, base);
            }
        }

        // Nothing to discard: the cache must stay exactly as it is.
        if amount == 0 {
            return;
        }
        // Every valid position is discarded: no data needs to move.
        if amount >= self.seq_len {
            self.seq_len = 0;
            return;
        }

        let new_len = self.seq_len - amount;
        for layer_idx in 0..self.num_layers {
            // Only layers that own their buffers hold data to shift.
            if self.kv_source_layer[layer_idx] != layer_idx {
                continue;
            }

            let shape = self.k_cache[layer_idx].shape();
            if shape.len() < 3 {
                continue;
            }
            let num_heads = shape[0];
            let max_seq = shape[1];
            let dim = shape[2];

            // Each head is treated as a contiguous run of `max_seq * dim` floats.
            let row_stride = max_seq * dim;
            let skip = amount * dim;
            let keep = new_len * dim;

            // The V buffer is shifted using the K tensor's dimensions.
            if let Ok(k_data) = self.k_cache[layer_idx].as_f32_mut() {
                shift_rows(k_data, num_heads, row_stride, skip, keep);
            }
            if let Ok(v_data) = self.v_cache[layer_idx].as_f32_mut() {
                shift_rows(v_data, num_heads, row_stride, skip, keep);
            }
        }
        self.seq_len = new_len;
    }

    /// Total size in bytes of all K and V tensors, counting four bytes
    /// (`size_of::<f32>()`) per element.
    pub fn memory_usage(&self) -> usize {
        self.k_cache
            .iter()
            .chain(self.v_cache.iter())
            .map(|t| t.numel() * std::mem::size_of::<f32>())
            .sum()
    }
}
/// Selects how an inference context stores its KV cache.
///
/// `F32` yields no TurboQuant configuration; the two TurboQuant variants yield
/// a `TurboQuantConfig` carrying their `bits` value (see `to_tq_config`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum KVCacheType {
/// Plain `F32` cache with no TurboQuant cache attached. This is the default.
F32,
/// TurboQuant with `use_qjl = false`; `bits` is forwarded unchanged to the
/// config. (Name suggests an MSE-oriented quantizer — confirm in `turboquant`.)
TurboQuantMSE { bits: u8 },
/// TurboQuant with `use_qjl = true`; `bits` is forwarded unchanged to the
/// config. (Name suggests an inner-product-oriented quantizer — confirm in
/// `turboquant`.)
TurboQuantProd { bits: u8 },
}
impl Default for KVCacheType {
fn default() -> Self {
Self::F32
}
}
impl KVCacheType {
    /// Builds the TurboQuant configuration for this cache type.
    ///
    /// Returns `None` for `F32`. For the TurboQuant variants the result carries
    /// the variant's `bits`, the given `dim`, and `use_qjl` set to `false` for
    /// `TurboQuantMSE` and `true` for `TurboQuantProd`.
    pub fn to_tq_config(&self, dim: usize) -> Option<TurboQuantConfig> {
        let (bits, use_qjl) = match *self {
            Self::F32 => return None,
            Self::TurboQuantMSE { bits } => (bits, false),
            Self::TurboQuantProd { bits } => (bits, true),
        };
        Some(TurboQuantConfig { bits, use_qjl, dim })
    }

    /// Returns `true` for every variant except `F32`.
    pub fn is_turboquant(&self) -> bool {
        match self {
            Self::F32 => false,
            Self::TurboQuantMSE { .. } | Self::TurboQuantProd { .. } => true,
        }
    }
}
/// Mutable state handed to `Model::forward`: caches, backend handle and
/// position counter.
pub struct InferenceContext {
/// `F32` KV cache, built from the model config by the constructors.
pub kv_cache: KVCache,
/// Shared handle to the compute backend.
pub backend: Arc<dyn Backend>,
/// Position counter; starts at 0 and is set back to 0 by `reset`.
/// Presumably the index of the next token to process — confirm in the
/// model implementations that update it.
pub position: usize,
/// Recurrent-layer state; `Some` only when built via `new_with_recurrent`.
pub recurrent_state: Option<RecurrentState>,
/// TurboQuant KV cache; `Some` only when built via `new_with_cache_type`
/// with a TurboQuant cache type.
pub tq_cache: Option<TurboQuantKVCache>,
}
/// Builds the `F32` KV cache described by `config`.
///
/// When `attention_layer_configs` is present, a heterogeneous cache is created
/// using `kv_source_layer` from the config, or the identity mapping over
/// `num_layers` when that is absent. Otherwise a uniform cache is created from
/// `num_layers`, `num_kv_heads`, `max_seq_len` and `key_length`.
fn build_kv_cache(config: &ModelConfig) -> KVCache {
    match config.attention_layer_configs {
        Some(ref layer_configs) => {
            let kv_mapping = match &config.kv_source_layer {
                Some(mapping) => mapping.clone(),
                // No explicit mapping: every layer owns its own buffers.
                None => (0..config.num_layers).collect(),
            };
            KVCache::new_heterogeneous(layer_configs, config.max_seq_len, kv_mapping)
        }
        None => KVCache::new(
            config.num_layers,
            config.num_kv_heads,
            config.max_seq_len,
            config.key_length,
        ),
    }
}
impl InferenceContext {
pub fn new(config: &ModelConfig, backend: Arc<dyn Backend>) -> Self {
Self {
kv_cache: build_kv_cache(config),
backend,
position: 0,
recurrent_state: None,
tq_cache: None,
}
}
pub fn new_with_cache_type(
config: &ModelConfig,
backend: Arc<dyn Backend>,
cache_type: KVCacheType,
) -> Self {
let tq_cache = cache_type
.to_tq_config(config.key_length)
.map(|tq_config| {
TurboQuantKVCache::new(
config.num_layers,
config.num_kv_heads,
config.max_seq_len,
config.key_length,
tq_config,
)
});
Self {
kv_cache: build_kv_cache(config),
backend,
position: 0,
recurrent_state: None,
tq_cache,
}
}
pub fn new_with_recurrent(
config: &ModelConfig,
backend: Arc<dyn Backend>,
is_recurrent: &[bool],
rc: &RecurrentConfig,
) -> Self {
Self {
kv_cache: build_kv_cache(config),
backend,
position: 0,
recurrent_state: Some(RecurrentState::new(
config.num_layers,
is_recurrent,
rc,
)),
tq_cache: None,
}
}
pub fn reset(&mut self) {
self.kv_cache.reset();
self.position = 0;
if let Some(ref mut rs) = self.recurrent_state {
rs.reset();
}
if let Some(ref mut tq) = self.tq_cache {
tq.reset();
}
}
pub fn has_turboquant(&self) -> bool {
self.tq_cache.is_some()
}
}
/// Common interface of the model implementations in this module.
///
/// Implementors must be `Send + Sync`.
pub trait Model: Send + Sync {
    /// Runs the model on `tokens`, using and updating the state in `ctx`, and
    /// returns the resulting tensor or a `ModelError`.
    fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor>;

    /// Configuration this model was built with.
    fn config(&self) -> &ModelConfig;

    /// Architecture tag of this model.
    fn architecture(&self) -> Architecture;

    /// Creates a fresh inference context for this model on `backend`.
    ///
    /// The default builds it with `InferenceContext::new`.
    fn create_context(&self, backend: Arc<dyn Backend>) -> InferenceContext {
        let config = self.config();
        InferenceContext::new(config, backend)
    }

    /// Vocabulary size, read from the model configuration.
    fn vocab_size(&self) -> usize {
        let config = self.config();
        config.vocab_size
    }

    /// Maximum sequence length, read from the model configuration.
    fn max_seq_len(&self) -> usize {
        let config = self.config();
        config.max_seq_len
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default cache type is plain `F32`.
    #[test]
    fn test_kv_cache_type_default() {
        let default_type = KVCacheType::default();
        assert_eq!(default_type, KVCacheType::F32);
    }

    /// Only the TurboQuant variants report `is_turboquant`.
    #[test]
    fn test_kv_cache_type_is_turboquant() {
        let plain = KVCacheType::F32;
        let mse = KVCacheType::TurboQuantMSE { bits: 2 };
        let prod = KVCacheType::TurboQuantProd { bits: 3 };

        assert!(!plain.is_turboquant());
        assert!(mse.is_turboquant());
        assert!(prod.is_turboquant());
    }

    /// `to_tq_config` forwards `bits` and `dim` and sets `use_qjl` per variant.
    #[test]
    fn test_kv_cache_type_to_tq_config() {
        assert!(KVCacheType::F32.to_tq_config(64).is_none());

        let mse = KVCacheType::TurboQuantMSE { bits: 2 }
            .to_tq_config(128)
            .unwrap();
        assert_eq!(mse.bits, 2);
        assert_eq!(mse.dim, 128);
        assert!(!mse.use_qjl);

        let prod = KVCacheType::TurboQuantProd { bits: 3 }
            .to_tq_config(64)
            .unwrap();
        assert_eq!(prod.bits, 3);
        assert_eq!(prod.dim, 64);
        assert!(prod.use_qjl);
    }

    /// Every variant survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_kv_cache_type_serde_roundtrip() {
        let variants = [
            KVCacheType::F32,
            KVCacheType::TurboQuantMSE { bits: 2 },
            KVCacheType::TurboQuantProd { bits: 3 },
        ];
        for original in variants.iter() {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: KVCacheType = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*original, decoded);
        }
    }

    /// Layers with different head counts and head sizes get differently
    /// shaped K and V buffers.
    #[test]
    fn test_kv_cache_heterogeneous() {
        use crate::model::config::{AttentionLayerConfig, AttentionLayerType};

        let sliding = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 256,
            num_kv_heads: 4,
            rope_freq_base: 10000.0,
            rope_dims: 256,
            sliding_window: 1024,
        };
        let global = AttentionLayerConfig {
            layer_type: AttentionLayerType::Global,
            head_dim: 512,
            num_kv_heads: 2,
            rope_freq_base: 1_000_000.0,
            rope_dims: 128,
            sliding_window: 0,
        };
        let layer_configs = vec![sliding, global];
        let identity_mapping = vec![0, 1];

        let cache = super::KVCache::new_heterogeneous(&layer_configs, 128, identity_mapping);

        // Expected `[num_kv_heads, max_seq_len, head_dim]` per layer.
        let expected_shapes = [[4usize, 128, 256], [2, 128, 512]];
        for (layer, expected) in expected_shapes.iter().enumerate() {
            assert_eq!(cache.k_cache[layer].shape(), expected);
            assert_eq!(cache.v_cache[layer].shape(), expected);
        }
    }

    /// A layer mapped to another layer gets only a zero-element placeholder.
    #[test]
    fn test_kv_cache_shared_layers() {
        use crate::model::config::{AttentionLayerConfig, AttentionLayerType};

        let cfg = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 128,
            num_kv_heads: 4,
            rope_freq_base: 10000.0,
            rope_dims: 128,
            sliding_window: 1024,
        };
        let layer_configs = vec![cfg; 3];
        // Layer 2 refers to layer 0 instead of owning buffers.
        let sharing = vec![0, 1, 0];

        let cache = super::KVCache::new_heterogeneous(&layer_configs, 64, sharing);

        assert_eq!(cache.k_cache[0].shape(), &[4, 64, 128]);
        assert_eq!(cache.k_cache[1].shape(), &[4, 64, 128]);
        assert_eq!(cache.k_cache[2].shape(), &[0]);
        assert_eq!(cache.kv_source_layer[2], 0);
    }
}