1mod architecture;
12pub mod cache;
13mod config;
14mod kv_quantized;
15pub mod kv_turboquant;
16pub mod deltanet;
17pub mod mamba;
18pub mod embeddings;
19mod error;
20pub mod hf_config;
21pub mod layers;
22mod llama;
23pub mod bert;
24mod loader;
25pub mod lora;
26pub mod moe;
27pub mod paged;
28pub mod speculative;
29pub mod turboquant;
30
31pub use architecture::Architecture;
32pub use kv_quantized::{KVCacheFormat, QuantizedKVCache};
33pub use kv_turboquant::TurboQuantKVCache;
34pub use turboquant::TurboQuantConfig;
35pub use cache::{
36 CachedPrefix, PrefixId, PrefixSharing, PromptCache, PromptCacheConfig, PromptCacheStats,
37};
38pub use config::{ActivationType, AttentionLayerConfig, AttentionLayerType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
39pub use embeddings::{
40 EmbeddingConfig, EmbeddingError, EmbeddingExtractor, PoolingStrategy, TruncationStrategy,
41 cosine_similarity, dot_product, euclidean_distance, find_nearest,
42};
43pub use error::{ModelError, ModelResult};
44pub use hf_config::{HfConfig, RopeScalingConfig};
45pub use deltanet::{
46 DeltaNetConfig, DeltaNetLayer, DeltaNetState, RecurrentConfig, RecurrentLayerState,
47 RecurrentState,
48};
49pub use mamba::{MambaConfig, MambaState, MambaLayer};
50pub use bert::{BertLayer, BertModel};
51pub use layers::{AttentionLayer, FfnLayer, TransformerLayer};
52pub use llama::LlamaModel;
53pub use loader::{ModelLoader, ModelSource, build_llama_model, load_llama_model};
54pub use lora::{LoraAdapter, LoraAdapters, LoraConfig};
55pub use moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter, MoeStats};
56pub use paged::{BlockId, BlockTable, PageAllocator, PagedKVPool, PagedSequence, DEFAULT_BLOCK_SIZE};
57pub use speculative::{SpeculativeConfig, SpeculativeDecoder, SpeculativeMode, SpeculativeStats};
58
59use std::sync::Arc;
60
61use crate::backend::Backend;
62use crate::tensor::Tensor;
63
64#[derive(Debug)]
66pub struct KVCache {
67 pub k_cache: Vec<Tensor>,
69 pub v_cache: Vec<Tensor>,
71 pub seq_len: usize,
73 pub max_seq_len: usize,
75 pub num_kv_heads: usize,
77 pub head_dim: usize,
79 pub num_layers: usize,
81 pub kv_source_layer: Vec<usize>,
83}
84
85impl KVCache {
86 pub fn new(
88 num_layers: usize,
89 num_kv_heads: usize,
90 max_seq_len: usize,
91 head_dim: usize,
92 ) -> Self {
93 use crate::tensor::DType;
94
95 let k_cache: Vec<Tensor> = (0..num_layers)
96 .map(|_| Tensor::zeros(vec![num_kv_heads, max_seq_len, head_dim], DType::F32))
97 .collect();
98
99 let v_cache: Vec<Tensor> = (0..num_layers)
100 .map(|_| Tensor::zeros(vec![num_kv_heads, max_seq_len, head_dim], DType::F32))
101 .collect();
102
103 Self {
104 k_cache,
105 v_cache,
106 seq_len: 0,
107 max_seq_len,
108 num_kv_heads,
109 head_dim,
110 num_layers,
111 kv_source_layer: (0..num_layers).collect(),
112 }
113 }
114
115 pub fn new_heterogeneous(
117 layer_configs: &[AttentionLayerConfig],
118 max_seq_len: usize,
119 kv_source_layer: Vec<usize>,
120 ) -> Self {
121 use crate::tensor::DType;
122
123 let num_layers = layer_configs.len();
124
125 let k_cache: Vec<Tensor> = (0..num_layers)
128 .map(|i| {
129 if kv_source_layer[i] == i {
130 let cfg = &layer_configs[i];
131 Tensor::zeros(
132 vec![cfg.num_kv_heads, max_seq_len, cfg.head_dim],
133 DType::F32,
134 )
135 } else {
136 Tensor::zeros(vec![0], DType::F32)
138 }
139 })
140 .collect();
141
142 let v_cache: Vec<Tensor> = (0..num_layers)
143 .map(|i| {
144 if kv_source_layer[i] == i {
145 let cfg = &layer_configs[i];
146 Tensor::zeros(
147 vec![cfg.num_kv_heads, max_seq_len, cfg.head_dim],
148 DType::F32,
149 )
150 } else {
151 Tensor::zeros(vec![0], DType::F32)
152 }
153 })
154 .collect();
155
156 let first = &layer_configs[0];
158 Self {
159 k_cache,
160 v_cache,
161 seq_len: 0,
162 max_seq_len,
163 num_kv_heads: first.num_kv_heads,
164 head_dim: first.head_dim,
165 num_layers,
166 kv_source_layer,
167 }
168 }
169
170 pub fn reset(&mut self) {
176 self.seq_len = 0;
177 }
178
179 pub fn remaining_capacity(&self) -> usize {
181 self.max_seq_len.saturating_sub(self.seq_len)
182 }
183
184 pub fn is_full(&self) -> bool {
186 self.seq_len >= self.max_seq_len
187 }
188
189 pub fn truncate(&mut self, new_len: usize) {
191 if new_len < self.seq_len {
192 self.seq_len = new_len;
193 }
194 }
195
196 pub fn shift_left(&mut self, amount: usize) {
203 if amount == 0 || amount >= self.seq_len {
204 self.seq_len = 0;
205 return;
206 }
207
208 let new_len = self.seq_len - amount;
209
210 for layer_idx in 0..self.num_layers {
211 if self.kv_source_layer[layer_idx] != layer_idx {
213 continue;
214 }
215
216 let shape = self.k_cache[layer_idx].shape();
217 if shape.len() < 3 {
218 continue; }
220 let num_heads = shape[0];
221 let max_seq = shape[1];
222 let dim = shape[2];
223 let row_stride = max_seq * dim;
224 let copy_elems = new_len * dim;
225
226 if let Ok(k_data) = self.k_cache[layer_idx].as_f32_mut() {
227 for head in 0..num_heads {
228 let base = head * row_stride;
229 let src_start = base + amount * dim;
230 k_data.copy_within(src_start..src_start + copy_elems, base);
231 }
232 }
233
234 if let Ok(v_data) = self.v_cache[layer_idx].as_f32_mut() {
235 for head in 0..num_heads {
236 let base = head * row_stride;
237 let src_start = base + amount * dim;
238 v_data.copy_within(src_start..src_start + copy_elems, base);
239 }
240 }
241 }
242
243 self.seq_len = new_len;
244 }
245
246 pub fn memory_usage(&self) -> usize {
248 self.k_cache
249 .iter()
250 .chain(self.v_cache.iter())
251 .map(|t| t.numel() * 4) .sum()
253 }
254}
255
256#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
258pub enum KVCacheType {
259 F32,
261 TurboQuantMSE { bits: u8 },
263 TurboQuantProd { bits: u8 },
265}
266
267impl Default for KVCacheType {
268 fn default() -> Self {
269 Self::F32
270 }
271}
272
273impl KVCacheType {
274 pub fn to_tq_config(&self, dim: usize) -> Option<TurboQuantConfig> {
276 match *self {
277 Self::F32 => None,
278 Self::TurboQuantMSE { bits } => Some(TurboQuantConfig {
279 bits,
280 use_qjl: false,
281 dim,
282 }),
283 Self::TurboQuantProd { bits } => Some(TurboQuantConfig {
284 bits,
285 use_qjl: true,
286 dim,
287 }),
288 }
289 }
290
291 pub fn is_turboquant(&self) -> bool {
293 !matches!(self, Self::F32)
294 }
295}
296
297pub struct InferenceContext {
299 pub kv_cache: KVCache,
301 pub backend: Arc<dyn Backend>,
303 pub position: usize,
305 pub recurrent_state: Option<RecurrentState>,
307 pub tq_cache: Option<TurboQuantKVCache>,
309}
310
311fn build_kv_cache(config: &ModelConfig) -> KVCache {
312 if let Some(ref layer_configs) = config.attention_layer_configs {
313 let kv_mapping = config
314 .kv_source_layer
315 .clone()
316 .unwrap_or_else(|| (0..config.num_layers).collect());
317 KVCache::new_heterogeneous(layer_configs, config.max_seq_len, kv_mapping)
318 } else {
319 KVCache::new(
320 config.num_layers,
321 config.num_kv_heads,
322 config.max_seq_len,
323 config.key_length,
324 )
325 }
326}
327
328impl InferenceContext {
329 pub fn new(config: &ModelConfig, backend: Arc<dyn Backend>) -> Self {
331 Self {
332 kv_cache: build_kv_cache(config),
333 backend,
334 position: 0,
335 recurrent_state: None,
336 tq_cache: None,
337 }
338 }
339
340 pub fn new_with_cache_type(
342 config: &ModelConfig,
343 backend: Arc<dyn Backend>,
344 cache_type: KVCacheType,
345 ) -> Self {
346 let tq_cache = cache_type
347 .to_tq_config(config.key_length)
348 .map(|tq_config| {
349 TurboQuantKVCache::new(
350 config.num_layers,
351 config.num_kv_heads,
352 config.max_seq_len,
353 config.key_length,
354 tq_config,
355 )
356 });
357
358 Self {
359 kv_cache: build_kv_cache(config),
360 backend,
361 position: 0,
362 recurrent_state: None,
363 tq_cache,
364 }
365 }
366
367 pub fn new_with_recurrent(
370 config: &ModelConfig,
371 backend: Arc<dyn Backend>,
372 is_recurrent: &[bool],
373 rc: &RecurrentConfig,
374 ) -> Self {
375 Self {
376 kv_cache: build_kv_cache(config),
377 backend,
378 position: 0,
379 recurrent_state: Some(RecurrentState::new(
380 config.num_layers,
381 is_recurrent,
382 rc,
383 )),
384 tq_cache: None,
385 }
386 }
387
388 pub fn reset(&mut self) {
390 self.kv_cache.reset();
391 self.position = 0;
392 if let Some(ref mut rs) = self.recurrent_state {
393 rs.reset();
394 }
395 if let Some(ref mut tq) = self.tq_cache {
396 tq.reset();
397 }
398 }
399
400 pub fn has_turboquant(&self) -> bool {
402 self.tq_cache.is_some()
403 }
404}
405
406pub trait Model: Send + Sync {
408 fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor>;
417
418 fn config(&self) -> &ModelConfig;
420
421 fn architecture(&self) -> Architecture;
423
424 fn create_context(&self, backend: Arc<dyn Backend>) -> InferenceContext {
426 InferenceContext::new(self.config(), backend)
427 }
428
429 fn vocab_size(&self) -> usize {
431 self.config().vocab_size
432 }
433
434 fn max_seq_len(&self) -> usize {
436 self.config().max_seq_len
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// The default cache type is the full-precision one.
    #[test]
    fn test_kv_cache_type_default() {
        let default_type = KVCacheType::default();
        assert_eq!(default_type, KVCacheType::F32);
    }

    /// Only the TurboQuant variants report themselves as TurboQuant.
    #[test]
    fn test_kv_cache_type_is_turboquant() {
        let cases = [
            (KVCacheType::F32, false),
            (KVCacheType::TurboQuantMSE { bits: 2 }, true),
            (KVCacheType::TurboQuantProd { bits: 3 }, true),
        ];

        for &(ty, expected) in cases.iter() {
            assert_eq!(ty.is_turboquant(), expected);
        }
    }

    /// `to_tq_config` forwards bits and dim and picks `use_qjl` per variant.
    #[test]
    fn test_kv_cache_type_to_tq_config() {
        assert!(KVCacheType::F32.to_tq_config(64).is_none());

        // (cache type, dim, expected bits, expected use_qjl)
        let cases: [(KVCacheType, usize, u8, bool); 2] = [
            (KVCacheType::TurboQuantMSE { bits: 2 }, 128, 2, false),
            (KVCacheType::TurboQuantProd { bits: 3 }, 64, 3, true),
        ];

        for &(ty, dim, bits, use_qjl) in cases.iter() {
            let cfg = ty.to_tq_config(dim).unwrap();
            assert_eq!(cfg.bits, bits);
            assert_eq!(cfg.dim, dim);
            assert_eq!(cfg.use_qjl, use_qjl);
        }
    }

    /// Every variant survives a JSON round trip unchanged.
    #[test]
    fn test_kv_cache_type_serde_roundtrip() {
        let types = [
            KVCacheType::F32,
            KVCacheType::TurboQuantMSE { bits: 2 },
            KVCacheType::TurboQuantProd { bits: 3 },
        ];

        types.iter().for_each(|ty| {
            let json = serde_json::to_string(ty).unwrap();
            let parsed: KVCacheType = serde_json::from_str(&json).unwrap();
            assert_eq!(*ty, parsed);
        });
    }

    /// Each owning layer gets buffers shaped from its own config.
    #[test]
    fn test_kv_cache_heterogeneous() {
        use crate::model::config::{AttentionLayerConfig, AttentionLayerType};

        let sliding = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 256,
            num_kv_heads: 4,
            rope_freq_base: 10000.0,
            rope_dims: 256,
            sliding_window: 1024,
        };
        let global = AttentionLayerConfig {
            layer_type: AttentionLayerType::Global,
            head_dim: 512,
            num_kv_heads: 2,
            rope_freq_base: 1_000_000.0,
            rope_dims: 128,
            sliding_window: 0,
        };

        let cache = super::KVCache::new_heterogeneous(&[sliding, global], 128, vec![0, 1]);

        let expected: [[usize; 3]; 2] = [[4, 128, 256], [2, 128, 512]];
        for (layer, shape) in expected.iter().enumerate() {
            assert_eq!(cache.k_cache[layer].shape(), shape);
            assert_eq!(cache.v_cache[layer].shape(), shape);
        }
    }

    /// A layer mapped to another layer only gets a zero-sized placeholder.
    #[test]
    fn test_kv_cache_shared_layers() {
        use crate::model::config::{AttentionLayerConfig, AttentionLayerType};

        let cfg = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 128,
            num_kv_heads: 4,
            rope_freq_base: 10000.0,
            rope_dims: 128,
            sliding_window: 1024,
        };
        let configs = vec![cfg; 3];

        // Layers 0 and 1 own their buffers; layer 2 points at layer 0.
        let cache = super::KVCache::new_heterogeneous(&configs, 64, vec![0, 1, 0]);

        let owned: [usize; 3] = [4, 64, 128];
        assert_eq!(cache.k_cache[0].shape(), &owned);
        assert_eq!(cache.k_cache[1].shape(), &owned);
        assert_eq!(cache.k_cache[2].shape(), &[0]);
        assert_eq!(cache.kv_source_layer[2], 0);
    }
}