Skip to main content

llama_rs/model/
config.rs

1//! Model configuration types
2
3use serde::{Deserialize, Serialize};
4
5/// RoPE implementation type
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
7pub enum RopeType {
8    /// Normal/LLaMA style: consecutive pairs (x[2i], x[2i+1])
9    #[default]
10    Normal,
11    /// NeoX/Qwen2 style: first half paired with second half (x[i], x[i+d/2])
12    NeoX,
13}
14
15/// Configuration for Rotary Position Embeddings (RoPE)
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RopeConfig {
18    /// Base frequency for RoPE (typically 10000.0)
19    pub freq_base: f32,
20    /// Frequency scale factor
21    pub freq_scale: f32,
22    /// Number of dimensions to apply RoPE to (usually head_dim)
23    pub n_dims: usize,
24    /// RoPE scaling type
25    pub scaling_type: RopeScalingType,
26    /// Original context length (for scaled RoPE)
27    pub original_max_position_embeddings: usize,
28    /// RoPE implementation type (Normal vs NeoX)
29    pub rope_type: RopeType,
30    /// MRoPE dimension sections (number of pairs per axis).
31    /// E.g., [11, 11, 10] means 3 axes with 22, 22, 20 dims each.
32    /// When present, frequency index resets per section.
33    pub mrope_sections: Option<Vec<usize>>,
34}
35
36impl Default for RopeConfig {
37    fn default() -> Self {
38        Self {
39            freq_base: 10000.0,
40            freq_scale: 1.0,
41            n_dims: 0, // Will be set from head_dim
42            scaling_type: RopeScalingType::None,
43            original_max_position_embeddings: 2048,
44            rope_type: RopeType::Normal,
45            mrope_sections: None,
46        }
47    }
48}
49
50/// RoPE scaling types for extended context
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RopeScalingType {
53    /// No scaling
54    #[default]
55    None,
56    /// Linear scaling (divide positions by factor)
57    Linear,
58    /// YaRN (Yet another RoPE extension)
59    Yarn,
60    /// Dynamic NTK-aware scaling
61    DynamicNtk,
62}
63
64/// Attention layer type for architectures with heterogeneous attention
65/// (e.g., Gemma 4's 5:1 sliding/global pattern).
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum AttentionLayerType {
68    /// Sliding-window attention with local context.
69    Sliding,
70    /// Full global attention over the entire sequence.
71    Global,
72}
73
74/// Per-layer attention configuration for heterogeneous architectures.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct AttentionLayerConfig {
77    /// Whether this layer uses sliding or global attention.
78    pub layer_type: AttentionLayerType,
79    /// Per-head dimension for this layer's K/Q projections.
80    pub head_dim: usize,
81    /// Number of KV heads for this layer.
82    pub num_kv_heads: usize,
83    /// RoPE frequency base for this layer.
84    pub rope_freq_base: f32,
85    /// Number of head dimensions to apply RoPE to.
86    pub rope_dims: usize,
87    /// Sliding window size (0 = full attention).
88    pub sliding_window: usize,
89}
90
91/// Full model configuration
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ModelConfig {
94    /// Vocabulary size
95    pub vocab_size: usize,
96    /// Hidden dimension (embedding size)
97    pub hidden_size: usize,
98    /// Intermediate size (FFN dimension, typically 4 * hidden_size or computed)
99    pub intermediate_size: usize,
100    /// Number of transformer layers
101    pub num_layers: usize,
102    /// Number of attention heads
103    pub num_heads: usize,
104    /// Number of key-value heads (for GQA/MQA)
105    pub num_kv_heads: usize,
106    /// Dimension per head
107    pub head_dim: usize,
108    /// Maximum sequence length
109    pub max_seq_len: usize,
110    /// RMS normalization epsilon
111    pub norm_eps: f32,
112    /// RoPE configuration
113    pub rope_config: RopeConfig,
114    /// Whether to use parallel attention (compute QKV in parallel)
115    pub use_parallel_residual: bool,
116    /// Activation function type
117    pub hidden_act: ActivationType,
118    /// Whether there's a bias in attention projections
119    pub attention_bias: bool,
120    /// Whether there's a bias in MLP layers
121    pub mlp_bias: bool,
122    /// Tie word embeddings with output projection
123    pub tie_word_embeddings: bool,
124    /// Number of MoE experts (0 = dense model)
125    pub num_experts: usize,
126    /// Number of experts activated per token
127    pub num_experts_per_token: usize,
128    /// Expert FFN intermediate dimension (may differ from dense intermediate_size)
129    pub expert_intermediate_size: usize,
130    /// Per-head key dimension (defaults to head_dim if not specified)
131    pub key_length: usize,
132    /// Per-head value dimension (defaults to head_dim if not specified)
133    pub value_length: usize,
134    /// SSM/DeltaNet inner dimension (0 = no SSM layers)
135    pub ssm_d_inner: usize,
136    /// SSM state dimension (per-head key dim for delta-net)
137    pub ssm_d_state: usize,
138    /// SSM group count (number of key heads in delta-net)
139    pub ssm_n_group: usize,
140    /// SSM time step rank (number of value heads in delta-net)
141    pub ssm_dt_rank: usize,
142    /// SSM convolution kernel size
143    pub ssm_conv_kernel: usize,
144    /// Attention logit soft-capping value (Gemma2: 50.0, 0.0 = disabled)
145    pub attn_logit_softcap: f32,
146    /// Final logit soft-capping value (Gemma2: 30.0, 0.0 = disabled)
147    pub final_logit_softcap: f32,
148    /// Sliding window attention size (0 = disabled)
149    pub sliding_window: usize,
150    /// Whether this architecture uses combined QKV tensor
151    pub has_combined_qkv: bool,
152    /// Whether this architecture uses LayerNorm instead of RMSNorm
153    pub uses_layer_norm: bool,
154    /// Whether this architecture uses GELU activation
155    pub uses_gelu: bool,
156    /// Whether this architecture has a gate projection in FFN
157    pub has_ffn_gate: bool,
158    /// Per-layer attention configs for architectures with heterogeneous
159    /// attention (e.g., Gemma 4). None means all layers use uniform config.
160    pub attention_layer_configs: Option<Vec<AttentionLayerConfig>>,
161    /// Maps layer index to physical KV cache slot. Identity by default.
162    /// For KV shared layers, multiple indices map to the same slot.
163    pub kv_source_layer: Option<Vec<usize>>,
164}
165
166impl Default for ModelConfig {
167    fn default() -> Self {
168        Self {
169            vocab_size: 32000,
170            hidden_size: 4096,
171            intermediate_size: 11008,
172            num_layers: 32,
173            num_heads: 32,
174            num_kv_heads: 32,
175            head_dim: 128,
176            max_seq_len: 2048,
177            norm_eps: 1e-5,
178            rope_config: RopeConfig::default(),
179            use_parallel_residual: false,
180            hidden_act: ActivationType::SiLU,
181            attention_bias: false,
182            mlp_bias: false,
183            tie_word_embeddings: false,
184            num_experts: 0,
185            num_experts_per_token: 0,
186            expert_intermediate_size: 0,
187            key_length: 128,
188            value_length: 128,
189            ssm_d_inner: 0,
190            ssm_d_state: 0,
191            ssm_n_group: 0,
192            ssm_dt_rank: 0,
193            ssm_conv_kernel: 0,
194            attn_logit_softcap: 0.0,
195            final_logit_softcap: 0.0,
196            sliding_window: 0,
197            has_combined_qkv: false,
198            uses_layer_norm: false,
199            uses_gelu: false,
200            has_ffn_gate: true,
201            attention_layer_configs: None,
202            kv_source_layer: None,
203        }
204    }
205}
206
207impl ModelConfig {
208    /// Whether this model has SSM/delta-net recurrent layers
209    pub fn has_ssm(&self) -> bool {
210        self.ssm_d_inner > 0
211    }
212
213    /// Check if this is an MoE model
214    pub fn is_moe(&self) -> bool {
215        self.num_experts > 0
216    }
217
218    /// Create config for LLaMA 7B
219    pub fn llama_7b() -> Self {
220        Self {
221            vocab_size: 32000,
222            hidden_size: 4096,
223            intermediate_size: 11008,
224            num_layers: 32,
225            num_heads: 32,
226            num_kv_heads: 32,
227            head_dim: 128,
228            max_seq_len: 2048,
229            norm_eps: 1e-5,
230            rope_config: RopeConfig {
231                freq_base: 10000.0,
232                freq_scale: 1.0,
233                n_dims: 128,
234                scaling_type: RopeScalingType::None,
235                original_max_position_embeddings: 2048,
236                rope_type: RopeType::Normal,
237                mrope_sections: None,
238            },
239            use_parallel_residual: false,
240            hidden_act: ActivationType::SiLU,
241            attention_bias: false,
242            mlp_bias: false,
243            tie_word_embeddings: false,
244            num_experts: 0,
245            num_experts_per_token: 0,
246            expert_intermediate_size: 0,
247            key_length: 128,
248            value_length: 128,
249            ssm_d_inner: 0,
250            ssm_d_state: 0,
251            ssm_n_group: 0,
252            ssm_dt_rank: 0,
253            ssm_conv_kernel: 0,
254            attn_logit_softcap: 0.0,
255            final_logit_softcap: 0.0,
256            sliding_window: 0,
257            has_combined_qkv: false,
258            uses_layer_norm: false,
259            uses_gelu: false,
260            has_ffn_gate: true,
261            attention_layer_configs: None,
262            kv_source_layer: None,
263        }
264    }
265
266    /// Create config for LLaMA 2 7B
267    pub fn llama2_7b() -> Self {
268        let mut config = Self::llama_7b();
269        config.max_seq_len = 4096;
270        config.rope_config.original_max_position_embeddings = 4096;
271        config.attn_logit_softcap = 0.0;
272        config.final_logit_softcap = 0.0;
273        config.sliding_window = 0;
274        config.has_combined_qkv = false;
275        config.uses_layer_norm = false;
276        config.uses_gelu = false;
277        config.has_ffn_gate = true;
278        config
279    }
280
281    /// Create config for LLaMA 3 8B
282    pub fn llama3_8b() -> Self {
283        Self {
284            vocab_size: 128256,
285            hidden_size: 4096,
286            intermediate_size: 14336,
287            num_layers: 32,
288            num_heads: 32,
289            num_kv_heads: 8, // GQA
290            head_dim: 128,
291            max_seq_len: 8192,
292            norm_eps: 1e-5,
293            rope_config: RopeConfig {
294                freq_base: 500000.0,
295                freq_scale: 1.0,
296                n_dims: 128,
297                scaling_type: RopeScalingType::None,
298                original_max_position_embeddings: 8192,
299                rope_type: RopeType::Normal,
300                mrope_sections: None,
301            },
302            use_parallel_residual: false,
303            hidden_act: ActivationType::SiLU,
304            attention_bias: false,
305            mlp_bias: false,
306            tie_word_embeddings: false,
307            num_experts: 0,
308            num_experts_per_token: 0,
309            expert_intermediate_size: 0,
310            key_length: 128,
311            value_length: 128,
312            ssm_d_inner: 0,
313            ssm_d_state: 0,
314            ssm_n_group: 0,
315            ssm_dt_rank: 0,
316            ssm_conv_kernel: 0,
317            attn_logit_softcap: 0.0,
318            final_logit_softcap: 0.0,
319            sliding_window: 0,
320            has_combined_qkv: false,
321            uses_layer_norm: false,
322            uses_gelu: false,
323            has_ffn_gate: true,
324            attention_layer_configs: None,
325            kv_source_layer: None,
326        }
327    }
328
329    /// Check if this model uses Grouped Query Attention
330    pub fn uses_gqa(&self) -> bool {
331        self.num_kv_heads < self.num_heads
332    }
333
334    /// Get the number of query heads per KV head
335    pub fn num_queries_per_kv(&self) -> usize {
336        self.num_heads / self.num_kv_heads
337    }
338
339    /// Build attention layer configs for a sliding/global pattern.
340    ///
341    /// `pattern_period` layers form one cycle, where the last layer is Global
342    /// and the rest are Sliding. E.g., period=6 gives 5 sliding + 1 global.
343    pub fn build_attention_layer_configs(
344        num_layers: usize,
345        pattern_period: usize,
346        sliding_head_dim: usize,
347        sliding_kv_heads: usize,
348        sliding_rope_freq_base: f32,
349        sliding_window: usize,
350        global_head_dim: usize,
351        global_kv_heads: usize,
352        global_rope_freq_base: f32,
353        global_rope_dims: usize,
354    ) -> Vec<AttentionLayerConfig> {
355        (0..num_layers)
356            .map(|i| {
357                if i % pattern_period == pattern_period - 1 {
358                    AttentionLayerConfig {
359                        layer_type: AttentionLayerType::Global,
360                        head_dim: global_head_dim,
361                        num_kv_heads: global_kv_heads,
362                        rope_freq_base: global_rope_freq_base,
363                        rope_dims: global_rope_dims,
364                        sliding_window: 0,
365                    }
366                } else {
367                    AttentionLayerConfig {
368                        layer_type: AttentionLayerType::Sliding,
369                        head_dim: sliding_head_dim,
370                        num_kv_heads: sliding_kv_heads,
371                        rope_freq_base: sliding_rope_freq_base,
372                        rope_dims: sliding_head_dim,
373                        sliding_window,
374                    }
375                }
376            })
377            .collect()
378    }
379
380    /// Build attention layer configs from a per-layer boolean SWA pattern.
381    ///
382    /// `is_swa[i]` is true if layer `i` uses sliding-window attention, false
383    /// for global attention. This matches the `sliding_window_pattern` array
384    /// stored in Gemma 4 GGUF files.
385    #[allow(clippy::too_many_arguments)]
386    pub fn build_attention_layer_configs_from_pattern(
387        is_swa: &[bool],
388        sliding_head_dim: usize,
389        sliding_kv_heads: usize,
390        sliding_rope_freq_base: f32,
391        sliding_rope_dims: usize,
392        sliding_window: usize,
393        global_head_dim: usize,
394        global_kv_heads: usize,
395        global_rope_freq_base: f32,
396        global_rope_dims: usize,
397    ) -> Vec<AttentionLayerConfig> {
398        is_swa
399            .iter()
400            .map(|&swa| {
401                if swa {
402                    AttentionLayerConfig {
403                        layer_type: AttentionLayerType::Sliding,
404                        head_dim: sliding_head_dim,
405                        num_kv_heads: sliding_kv_heads,
406                        rope_freq_base: sliding_rope_freq_base,
407                        rope_dims: sliding_rope_dims,
408                        sliding_window,
409                    }
410                } else {
411                    AttentionLayerConfig {
412                        layer_type: AttentionLayerType::Global,
413                        head_dim: global_head_dim,
414                        num_kv_heads: global_kv_heads,
415                        rope_freq_base: global_rope_freq_base,
416                        rope_dims: global_rope_dims,
417                        sliding_window: 0,
418                    }
419                }
420            })
421            .collect()
422    }
423
424    /// Build KV source layer mapping for shared KV cache.
425    ///
426    /// The last `shared_layers` layers reuse cached K/V from earlier layers
427    /// instead of projecting their own. The mapping is TYPE-SPECIFIC: shared
428    /// SWA layers map to the last KV-owning SWA layer, shared global layers
429    /// map to the last KV-owning global layer.
430    ///
431    /// Requires `layer_configs` to determine each layer's type.
432    pub fn build_kv_source_mapping(
433        num_layers: usize,
434        shared_layers: usize,
435        layer_configs: &[AttentionLayerConfig],
436    ) -> Vec<usize> {
437        if shared_layers == 0 || shared_layers >= num_layers {
438            return (0..num_layers).collect();
439        }
440        let kv_boundary = num_layers - shared_layers;
441
442        // Find the last KV-owning layer for each type
443        let mut last_swa_kv = 0;
444        let mut last_global_kv = 0;
445        for i in 0..kv_boundary {
446            match layer_configs[i].layer_type {
447                AttentionLayerType::Sliding => last_swa_kv = i,
448                AttentionLayerType::Global => last_global_kv = i,
449            }
450        }
451
452        (0..num_layers)
453            .map(|i| {
454                if i < kv_boundary {
455                    i // owns its own KV cache
456                } else {
457                    // Shared: map to last KV-owning layer of same type
458                    match layer_configs[i].layer_type {
459                        AttentionLayerType::Sliding => last_swa_kv,
460                        AttentionLayerType::Global => last_global_kv,
461                    }
462                }
463            })
464            .collect()
465    }
466}
467
468/// Activation function types
469#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
470pub enum ActivationType {
471    /// Gaussian Error Linear Unit
472    GELU,
473    /// GELU approximation (tanh-based)
474    GELUApprox,
475    /// Sigmoid Linear Unit (Swish)
476    #[default]
477    SiLU,
478    /// Rectified Linear Unit
479    ReLU,
480    /// Squared ReLU
481    ReLUSquared,
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: sliding layers get head_dim 256, 4 KV heads and a
    /// 1024-token window; global layers get head_dim 512, 2 KV heads and 128
    /// RoPE dims.
    fn periodic_configs(num_layers: usize, period: usize) -> Vec<AttentionLayerConfig> {
        ModelConfig::build_attention_layer_configs(
            num_layers, period, 256, 4, 10000.0, 1024, 512, 2, 1_000_000.0, 128,
        )
    }

    #[test]
    fn test_default_config() {
        let cfg = ModelConfig::default();
        assert_eq!(
            (cfg.vocab_size, cfg.hidden_size, cfg.num_layers),
            (32000, 4096, 32)
        );
    }

    #[test]
    fn test_llama3_gqa() {
        let cfg = ModelConfig::llama3_8b();
        assert!(cfg.uses_gqa());
        assert_eq!(cfg.num_queries_per_kv(), 4);
    }

    #[test]
    fn test_llama_no_gqa() {
        let cfg = ModelConfig::llama_7b();
        assert!(!cfg.uses_gqa());
        assert_eq!(cfg.num_queries_per_kv(), 1);
    }

    #[test]
    fn test_attention_layer_configs_pattern() {
        let configs = periodic_configs(12, 6);
        assert_eq!(configs.len(), 12);
        for (i, layer) in configs.iter().enumerate() {
            let (kind, head_dim, kv_heads, window, rope_dims) = if i % 6 == 5 {
                (AttentionLayerType::Global, 512, 2, 0, 128)
            } else {
                (AttentionLayerType::Sliding, 256, 4, 1024, 256)
            };
            assert_eq!(layer.layer_type, kind, "layer {i}");
            assert_eq!(layer.head_dim, head_dim, "layer {i}");
            assert_eq!(layer.num_kv_heads, kv_heads, "layer {i}");
            assert_eq!(layer.sliding_window, window, "layer {i}");
            assert_eq!(layer.rope_dims, rope_dims, "layer {i}");
        }
    }

    #[test]
    fn test_attention_layer_configs_from_bool_pattern() {
        // Gemma 4 E2B layout: 35 layers repeating four SWA layers then one global.
        let pattern: Vec<bool> = (0..35).map(|i| i % 5 != 4).collect();
        let configs = ModelConfig::build_attention_layer_configs_from_pattern(
            &pattern, 256, 1, 10000.0, 256, 512, 512, 1, 1_000_000.0, 512,
        );
        assert_eq!(configs.len(), 35);

        let first = &configs[0];
        assert_eq!(first.layer_type, AttentionLayerType::Sliding);
        assert_eq!(first.head_dim, 256);

        let fifth = &configs[4];
        assert_eq!(fifth.layer_type, AttentionLayerType::Global);
        assert_eq!(fifth.head_dim, 512);
        assert_eq!(fifth.sliding_window, 0);

        assert_eq!(configs[34].layer_type, AttentionLayerType::Global);
    }

    #[test]
    fn test_kv_source_mapping_no_sharing() {
        let configs = periodic_configs(6, 6);
        let mapping = ModelConfig::build_kv_source_mapping(6, 0, &configs);
        let identity: Vec<usize> = (0..6).collect();
        assert_eq!(mapping, identity);
    }

    #[test]
    fn test_kv_source_mapping_type_specific() {
        // 12 layers with period 5 => globals at layers 4 and 9, the rest SWA.
        // 7 shared layers => layers 0..=4 own a cache, 5..=11 borrow one.
        // Owners: last SWA layer is 3, last global layer is 4.
        let configs = periodic_configs(12, 5);
        let mapping = ModelConfig::build_kv_source_mapping(12, 7, &configs);
        let expected: Vec<usize> = vec![0, 1, 2, 3, 4, 3, 3, 3, 3, 4, 3, 3];
        assert_eq!(mapping, expected);
    }
}