Skip to main content

llama_rs/model/
mod.rs

1//! Model architectures and inference
2//!
3//! This module provides:
4//! - Model configuration types
5//! - Architecture definitions
6//! - The `Model` trait for inference
7//! - LLaMA and other model implementations
8//! - Model loading from GGUF files
9//! - Speculative decoding
10
11mod architecture;
12pub mod cache;
13mod config;
14mod kv_quantized;
15pub mod kv_turboquant;
16pub mod deltanet;
17pub mod mamba;
18pub mod embeddings;
19mod error;
20pub mod hf_config;
21pub mod layers;
22mod llama;
23pub mod bert;
24mod loader;
25pub mod lora;
26pub mod moe;
27pub mod paged;
28pub mod speculative;
29pub mod turboquant;
30
31pub use architecture::Architecture;
32pub use kv_quantized::{KVCacheFormat, QuantizedKVCache};
33pub use kv_turboquant::TurboQuantKVCache;
34pub use turboquant::TurboQuantConfig;
35pub use cache::{
36    CachedPrefix, PrefixId, PrefixSharing, PromptCache, PromptCacheConfig, PromptCacheStats,
37};
38pub use config::{ActivationType, AttentionLayerConfig, AttentionLayerType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
39pub use embeddings::{
40    EmbeddingConfig, EmbeddingError, EmbeddingExtractor, PoolingStrategy, TruncationStrategy,
41    cosine_similarity, dot_product, euclidean_distance, find_nearest,
42};
43pub use error::{ModelError, ModelResult};
44pub use hf_config::{HfConfig, RopeScalingConfig};
45pub use deltanet::{
46    DeltaNetConfig, DeltaNetLayer, DeltaNetState, RecurrentConfig, RecurrentLayerState,
47    RecurrentState,
48};
49pub use mamba::{MambaConfig, MambaState, MambaLayer};
50pub use bert::{BertLayer, BertModel};
51pub use layers::{AttentionLayer, FfnLayer, TransformerLayer};
52pub use llama::LlamaModel;
53pub use loader::{ModelLoader, ModelSource, build_llama_model, load_llama_model};
54pub use lora::{LoraAdapter, LoraAdapters, LoraConfig};
55pub use moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter, MoeStats};
56pub use paged::{BlockId, BlockTable, PageAllocator, PagedKVPool, PagedSequence, DEFAULT_BLOCK_SIZE};
57pub use speculative::{SpeculativeConfig, SpeculativeDecoder, SpeculativeMode, SpeculativeStats};
58
59use std::sync::Arc;
60
61use crate::backend::Backend;
62use crate::tensor::Tensor;
63
64/// KV cache for efficient autoregressive generation
65#[derive(Debug)]
66pub struct KVCache {
67    /// Key cache for each layer: [num_kv_heads, max_seq_len, head_dim]
68    pub k_cache: Vec<Tensor>,
69    /// Value cache for each layer
70    pub v_cache: Vec<Tensor>,
71    /// Current sequence length in cache
72    pub seq_len: usize,
73    /// Maximum sequence length
74    pub max_seq_len: usize,
75    /// Number of KV heads
76    pub num_kv_heads: usize,
77    /// Head dimension
78    pub head_dim: usize,
79    /// Number of layers
80    pub num_layers: usize,
81    /// Maps layer index to physical KV cache slot (identity by default).
82    pub kv_source_layer: Vec<usize>,
83}
84
85impl KVCache {
86    /// Create a new KV cache
87    pub fn new(
88        num_layers: usize,
89        num_kv_heads: usize,
90        max_seq_len: usize,
91        head_dim: usize,
92    ) -> Self {
93        use crate::tensor::DType;
94
95        let k_cache: Vec<Tensor> = (0..num_layers)
96            .map(|_| Tensor::zeros(vec![num_kv_heads, max_seq_len, head_dim], DType::F32))
97            .collect();
98
99        let v_cache: Vec<Tensor> = (0..num_layers)
100            .map(|_| Tensor::zeros(vec![num_kv_heads, max_seq_len, head_dim], DType::F32))
101            .collect();
102
103        Self {
104            k_cache,
105            v_cache,
106            seq_len: 0,
107            max_seq_len,
108            num_kv_heads,
109            head_dim,
110            num_layers,
111            kv_source_layer: (0..num_layers).collect(),
112        }
113    }
114
115    /// Create a KV cache with per-layer dimensions and optional KV sharing.
116    pub fn new_heterogeneous(
117        layer_configs: &[AttentionLayerConfig],
118        max_seq_len: usize,
119        kv_source_layer: Vec<usize>,
120    ) -> Self {
121        use crate::tensor::DType;
122
123        let num_layers = layer_configs.len();
124
125        // Only allocate cache for layers that own their slot (kv_source_layer[i] == i).
126        // Shared layers will index into another layer's cache.
127        let k_cache: Vec<Tensor> = (0..num_layers)
128            .map(|i| {
129                if kv_source_layer[i] == i {
130                    let cfg = &layer_configs[i];
131                    Tensor::zeros(
132                        vec![cfg.num_kv_heads, max_seq_len, cfg.head_dim],
133                        DType::F32,
134                    )
135                } else {
136                    // Placeholder — this layer uses another layer's cache
137                    Tensor::zeros(vec![0], DType::F32)
138                }
139            })
140            .collect();
141
142        let v_cache: Vec<Tensor> = (0..num_layers)
143            .map(|i| {
144                if kv_source_layer[i] == i {
145                    let cfg = &layer_configs[i];
146                    Tensor::zeros(
147                        vec![cfg.num_kv_heads, max_seq_len, cfg.head_dim],
148                        DType::F32,
149                    )
150                } else {
151                    Tensor::zeros(vec![0], DType::F32)
152                }
153            })
154            .collect();
155
156        // Use first layer's config as the default for methods that need uniform params
157        let first = &layer_configs[0];
158        Self {
159            k_cache,
160            v_cache,
161            seq_len: 0,
162            max_seq_len,
163            num_kv_heads: first.num_kv_heads,
164            head_dim: first.head_dim,
165            num_layers,
166            kv_source_layer,
167        }
168    }
169
170    /// Reset the cache for a new sequence.
171    ///
172    /// Only resets the position counter. Cache data is not zeroed because
173    /// `attention_cached` only reads positions `0..seq_len`, so stale data
174    /// beyond `seq_len` is never accessed.
175    pub fn reset(&mut self) {
176        self.seq_len = 0;
177    }
178
179    /// Get remaining capacity
180    pub fn remaining_capacity(&self) -> usize {
181        self.max_seq_len.saturating_sub(self.seq_len)
182    }
183
184    /// Check if cache is full
185    pub fn is_full(&self) -> bool {
186        self.seq_len >= self.max_seq_len
187    }
188
189    /// Truncate cache to a specific length (for context shifting)
190    pub fn truncate(&mut self, new_len: usize) {
191        if new_len < self.seq_len {
192            self.seq_len = new_len;
193        }
194    }
195
196    /// Shift cache left by `amount` positions (for sliding window).
197    /// Keeps the last `(seq_len - amount)` positions.
198    ///
199    /// Uses `copy_within` for each head's contiguous run, which compiles to
200    /// a single `memmove` — dramatically faster than the element-wise loop
201    /// it replaces (especially for long sequences).
202    pub fn shift_left(&mut self, amount: usize) {
203        if amount == 0 || amount >= self.seq_len {
204            self.seq_len = 0;
205            return;
206        }
207
208        let new_len = self.seq_len - amount;
209
210        for layer_idx in 0..self.num_layers {
211            // Skip shared layers — their anchor is shifted in its own iteration
212            if self.kv_source_layer[layer_idx] != layer_idx {
213                continue;
214            }
215
216            let shape = self.k_cache[layer_idx].shape();
217            if shape.len() < 3 {
218                continue; // placeholder tensor
219            }
220            let num_heads = shape[0];
221            let max_seq = shape[1];
222            let dim = shape[2];
223            let row_stride = max_seq * dim;
224            let copy_elems = new_len * dim;
225
226            if let Ok(k_data) = self.k_cache[layer_idx].as_f32_mut() {
227                for head in 0..num_heads {
228                    let base = head * row_stride;
229                    let src_start = base + amount * dim;
230                    k_data.copy_within(src_start..src_start + copy_elems, base);
231                }
232            }
233
234            if let Ok(v_data) = self.v_cache[layer_idx].as_f32_mut() {
235                for head in 0..num_heads {
236                    let base = head * row_stride;
237                    let src_start = base + amount * dim;
238                    v_data.copy_within(src_start..src_start + copy_elems, base);
239                }
240            }
241        }
242
243        self.seq_len = new_len;
244    }
245
246    /// Get memory usage in bytes
247    pub fn memory_usage(&self) -> usize {
248        self.k_cache
249            .iter()
250            .chain(self.v_cache.iter())
251            .map(|t| t.numel() * 4) // f32 = 4 bytes
252            .sum()
253    }
254}
255
256/// Which KV cache implementation to use.
257#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
258pub enum KVCacheType {
259    /// Standard f32 KV cache (default).
260    F32,
261    /// TurboQuant MSE: Hadamard rotation + scalar quantization (biased, lower overhead).
262    TurboQuantMSE { bits: u8 },
263    /// TurboQuant prod: MSE + QJL correction (unbiased, higher accuracy).
264    TurboQuantProd { bits: u8 },
265}
266
267impl Default for KVCacheType {
268    fn default() -> Self {
269        Self::F32
270    }
271}
272
273impl KVCacheType {
274    /// Convert to a `TurboQuantConfig` if this is a TurboQuant variant.
275    pub fn to_tq_config(&self, dim: usize) -> Option<TurboQuantConfig> {
276        match *self {
277            Self::F32 => None,
278            Self::TurboQuantMSE { bits } => Some(TurboQuantConfig {
279                bits,
280                use_qjl: false,
281                dim,
282            }),
283            Self::TurboQuantProd { bits } => Some(TurboQuantConfig {
284                bits,
285                use_qjl: true,
286                dim,
287            }),
288        }
289    }
290
291    /// Whether this is any TurboQuant variant.
292    pub fn is_turboquant(&self) -> bool {
293        !matches!(self, Self::F32)
294    }
295}
296
297/// Context for model inference
298pub struct InferenceContext {
299    /// KV cache for attention
300    pub kv_cache: KVCache,
301    /// Backend to use for computation
302    pub backend: Arc<dyn Backend>,
303    /// Current position in sequence
304    pub position: usize,
305    /// Recurrent state for delta-net layers (None if model has no SSM layers)
306    pub recurrent_state: Option<RecurrentState>,
307    /// Optional TurboQuant-compressed KV cache (replaces f32 cache for attention)
308    pub tq_cache: Option<TurboQuantKVCache>,
309}
310
311fn build_kv_cache(config: &ModelConfig) -> KVCache {
312    if let Some(ref layer_configs) = config.attention_layer_configs {
313        let kv_mapping = config
314            .kv_source_layer
315            .clone()
316            .unwrap_or_else(|| (0..config.num_layers).collect());
317        KVCache::new_heterogeneous(layer_configs, config.max_seq_len, kv_mapping)
318    } else {
319        KVCache::new(
320            config.num_layers,
321            config.num_kv_heads,
322            config.max_seq_len,
323            config.key_length,
324        )
325    }
326}
327
328impl InferenceContext {
329    /// Create a new inference context
330    pub fn new(config: &ModelConfig, backend: Arc<dyn Backend>) -> Self {
331        Self {
332            kv_cache: build_kv_cache(config),
333            backend,
334            position: 0,
335            recurrent_state: None,
336            tq_cache: None,
337        }
338    }
339
340    /// Create inference context with a specific KV cache type.
341    pub fn new_with_cache_type(
342        config: &ModelConfig,
343        backend: Arc<dyn Backend>,
344        cache_type: KVCacheType,
345    ) -> Self {
346        let tq_cache = cache_type
347            .to_tq_config(config.key_length)
348            .map(|tq_config| {
349                TurboQuantKVCache::new(
350                    config.num_layers,
351                    config.num_kv_heads,
352                    config.max_seq_len,
353                    config.key_length,
354                    tq_config,
355                )
356            });
357
358        Self {
359            kv_cache: build_kv_cache(config),
360            backend,
361            position: 0,
362            recurrent_state: None,
363            tq_cache,
364        }
365    }
366
367    /// Create inference context with recurrent state for SSM layers.
368    /// `is_recurrent[i]` marks which layers are recurrent (DeltaNet or Mamba).
369    pub fn new_with_recurrent(
370        config: &ModelConfig,
371        backend: Arc<dyn Backend>,
372        is_recurrent: &[bool],
373        rc: &RecurrentConfig,
374    ) -> Self {
375        Self {
376            kv_cache: build_kv_cache(config),
377            backend,
378            position: 0,
379            recurrent_state: Some(RecurrentState::new(
380                config.num_layers,
381                is_recurrent,
382                rc,
383            )),
384            tq_cache: None,
385        }
386    }
387
388    /// Reset context for a new sequence
389    pub fn reset(&mut self) {
390        self.kv_cache.reset();
391        self.position = 0;
392        if let Some(ref mut rs) = self.recurrent_state {
393            rs.reset();
394        }
395        if let Some(ref mut tq) = self.tq_cache {
396            tq.reset();
397        }
398    }
399
400    /// Whether TurboQuant KV cache is active.
401    pub fn has_turboquant(&self) -> bool {
402        self.tq_cache.is_some()
403    }
404}
405
406/// Trait for language models
407pub trait Model: Send + Sync {
408    /// Run forward pass and return logits
409    ///
410    /// # Arguments
411    /// * `tokens` - Input token IDs
412    /// * `ctx` - Inference context with KV cache
413    ///
414    /// # Returns
415    /// Logits tensor of shape [batch_size, vocab_size] or [batch_size, seq_len, vocab_size]
416    fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor>;
417
418    /// Get model configuration
419    fn config(&self) -> &ModelConfig;
420
421    /// Get model architecture
422    fn architecture(&self) -> Architecture;
423
424    /// Create an InferenceContext with the right state for this model.
425    fn create_context(&self, backend: Arc<dyn Backend>) -> InferenceContext {
426        InferenceContext::new(self.config(), backend)
427    }
428
429    /// Get vocabulary size
430    fn vocab_size(&self) -> usize {
431        self.config().vocab_size
432    }
433
434    /// Get maximum sequence length
435    fn max_seq_len(&self) -> usize {
436        self.config().max_seq_len
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kv_cache_type_default() {
        let fallback: KVCacheType = Default::default();
        assert_eq!(fallback, KVCacheType::F32);
    }

    #[test]
    fn test_kv_cache_type_is_turboquant() {
        assert!(!KVCacheType::F32.is_turboquant());

        let quantized = [
            KVCacheType::TurboQuantMSE { bits: 2 },
            KVCacheType::TurboQuantProd { bits: 3 },
        ];
        for ty in quantized.iter() {
            assert!(ty.is_turboquant(), "{:?} should be a TurboQuant variant", ty);
        }
    }

    #[test]
    fn test_kv_cache_type_to_tq_config() {
        assert!(KVCacheType::F32.to_tq_config(64).is_none());

        // (cache type, dim, expected bits, expected use_qjl)
        let cases: [(KVCacheType, usize, u8, bool); 2] = [
            (KVCacheType::TurboQuantMSE { bits: 2 }, 128, 2, false),
            (KVCacheType::TurboQuantProd { bits: 3 }, 64, 3, true),
        ];
        for &(ty, dim, bits, use_qjl) in cases.iter() {
            let cfg = ty
                .to_tq_config(dim)
                .expect("TurboQuant variants always produce a config");
            assert_eq!(cfg.bits, bits);
            assert_eq!(cfg.dim, dim);
            assert_eq!(cfg.use_qjl, use_qjl);
        }
    }

    #[test]
    fn test_kv_cache_type_serde_roundtrip() {
        let all = [
            KVCacheType::F32,
            KVCacheType::TurboQuantMSE { bits: 2 },
            KVCacheType::TurboQuantProd { bits: 3 },
        ];
        for original in all.iter() {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: KVCacheType = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, *original);
        }
    }

    #[test]
    fn test_kv_cache_heterogeneous() {
        let configs = vec![
            AttentionLayerConfig {
                layer_type: AttentionLayerType::Sliding,
                head_dim: 256,
                num_kv_heads: 4,
                rope_freq_base: 10000.0,
                rope_dims: 256,
                sliding_window: 1024,
            },
            AttentionLayerConfig {
                layer_type: AttentionLayerType::Global,
                head_dim: 512,
                num_kv_heads: 2,
                rope_freq_base: 1_000_000.0,
                rope_dims: 128,
                sliding_window: 0,
            },
        ];
        let cache = super::KVCache::new_heterogeneous(&configs, 128, vec![0, 1]);

        // Every layer owns its slot, so both k and v follow that layer's dims.
        let expected: [[usize; 3]; 2] = [[4, 128, 256], [2, 128, 512]];
        for (layer, shape) in expected.iter().enumerate() {
            assert_eq!(cache.k_cache[layer].shape(), shape);
            assert_eq!(cache.v_cache[layer].shape(), shape);
        }
    }

    #[test]
    fn test_kv_cache_shared_layers() {
        let cfg = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 128,
            num_kv_heads: 4,
            rope_freq_base: 10000.0,
            rope_dims: 128,
            sliding_window: 1024,
        };
        let configs = vec![cfg; 3];
        // Layers 0 and 1 own storage; layer 2 reuses layer 0's slot.
        let cache = super::KVCache::new_heterogeneous(&configs, 64, vec![0, 1, 0]);

        for owned in cache.k_cache.iter().take(2) {
            assert_eq!(owned.shape(), &[4, 64, 128]);
        }
        assert_eq!(cache.k_cache[2].shape(), &[0]);
        assert_eq!(cache.kv_source_layer[2], 0);
    }
}