1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
7pub enum RopeType {
8 #[default]
10 Normal,
11 NeoX,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RopeConfig {
18 pub freq_base: f32,
20 pub freq_scale: f32,
22 pub n_dims: usize,
24 pub scaling_type: RopeScalingType,
26 pub original_max_position_embeddings: usize,
28 pub rope_type: RopeType,
30 pub mrope_sections: Option<Vec<usize>>,
34}
35
36impl Default for RopeConfig {
37 fn default() -> Self {
38 Self {
39 freq_base: 10000.0,
40 freq_scale: 1.0,
41 n_dims: 0, scaling_type: RopeScalingType::None,
43 original_max_position_embeddings: 2048,
44 rope_type: RopeType::Normal,
45 mrope_sections: None,
46 }
47 }
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
52pub enum RopeScalingType {
53 #[default]
55 None,
56 Linear,
58 Yarn,
60 DynamicNtk,
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum AttentionLayerType {
68 Sliding,
70 Global,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct AttentionLayerConfig {
77 pub layer_type: AttentionLayerType,
79 pub head_dim: usize,
81 pub num_kv_heads: usize,
83 pub rope_freq_base: f32,
85 pub rope_dims: usize,
87 pub sliding_window: usize,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ModelConfig {
94 pub vocab_size: usize,
96 pub hidden_size: usize,
98 pub intermediate_size: usize,
100 pub num_layers: usize,
102 pub num_heads: usize,
104 pub num_kv_heads: usize,
106 pub head_dim: usize,
108 pub max_seq_len: usize,
110 pub norm_eps: f32,
112 pub rope_config: RopeConfig,
114 pub use_parallel_residual: bool,
116 pub hidden_act: ActivationType,
118 pub attention_bias: bool,
120 pub mlp_bias: bool,
122 pub tie_word_embeddings: bool,
124 pub num_experts: usize,
126 pub num_experts_per_token: usize,
128 pub expert_intermediate_size: usize,
130 pub key_length: usize,
132 pub value_length: usize,
134 pub ssm_d_inner: usize,
136 pub ssm_d_state: usize,
138 pub ssm_n_group: usize,
140 pub ssm_dt_rank: usize,
142 pub ssm_conv_kernel: usize,
144 pub attn_logit_softcap: f32,
146 pub final_logit_softcap: f32,
148 pub sliding_window: usize,
150 pub has_combined_qkv: bool,
152 pub uses_layer_norm: bool,
154 pub uses_gelu: bool,
156 pub has_ffn_gate: bool,
158 pub attention_layer_configs: Option<Vec<AttentionLayerConfig>>,
161 pub kv_source_layer: Option<Vec<usize>>,
164}
165
166impl Default for ModelConfig {
167 fn default() -> Self {
168 Self {
169 vocab_size: 32000,
170 hidden_size: 4096,
171 intermediate_size: 11008,
172 num_layers: 32,
173 num_heads: 32,
174 num_kv_heads: 32,
175 head_dim: 128,
176 max_seq_len: 2048,
177 norm_eps: 1e-5,
178 rope_config: RopeConfig::default(),
179 use_parallel_residual: false,
180 hidden_act: ActivationType::SiLU,
181 attention_bias: false,
182 mlp_bias: false,
183 tie_word_embeddings: false,
184 num_experts: 0,
185 num_experts_per_token: 0,
186 expert_intermediate_size: 0,
187 key_length: 128,
188 value_length: 128,
189 ssm_d_inner: 0,
190 ssm_d_state: 0,
191 ssm_n_group: 0,
192 ssm_dt_rank: 0,
193 ssm_conv_kernel: 0,
194 attn_logit_softcap: 0.0,
195 final_logit_softcap: 0.0,
196 sliding_window: 0,
197 has_combined_qkv: false,
198 uses_layer_norm: false,
199 uses_gelu: false,
200 has_ffn_gate: true,
201 attention_layer_configs: None,
202 kv_source_layer: None,
203 }
204 }
205}
206
207impl ModelConfig {
208 pub fn has_ssm(&self) -> bool {
210 self.ssm_d_inner > 0
211 }
212
213 pub fn is_moe(&self) -> bool {
215 self.num_experts > 0
216 }
217
218 pub fn llama_7b() -> Self {
220 Self {
221 vocab_size: 32000,
222 hidden_size: 4096,
223 intermediate_size: 11008,
224 num_layers: 32,
225 num_heads: 32,
226 num_kv_heads: 32,
227 head_dim: 128,
228 max_seq_len: 2048,
229 norm_eps: 1e-5,
230 rope_config: RopeConfig {
231 freq_base: 10000.0,
232 freq_scale: 1.0,
233 n_dims: 128,
234 scaling_type: RopeScalingType::None,
235 original_max_position_embeddings: 2048,
236 rope_type: RopeType::Normal,
237 mrope_sections: None,
238 },
239 use_parallel_residual: false,
240 hidden_act: ActivationType::SiLU,
241 attention_bias: false,
242 mlp_bias: false,
243 tie_word_embeddings: false,
244 num_experts: 0,
245 num_experts_per_token: 0,
246 expert_intermediate_size: 0,
247 key_length: 128,
248 value_length: 128,
249 ssm_d_inner: 0,
250 ssm_d_state: 0,
251 ssm_n_group: 0,
252 ssm_dt_rank: 0,
253 ssm_conv_kernel: 0,
254 attn_logit_softcap: 0.0,
255 final_logit_softcap: 0.0,
256 sliding_window: 0,
257 has_combined_qkv: false,
258 uses_layer_norm: false,
259 uses_gelu: false,
260 has_ffn_gate: true,
261 attention_layer_configs: None,
262 kv_source_layer: None,
263 }
264 }
265
266 pub fn llama2_7b() -> Self {
268 let mut config = Self::llama_7b();
269 config.max_seq_len = 4096;
270 config.rope_config.original_max_position_embeddings = 4096;
271 config.attn_logit_softcap = 0.0;
272 config.final_logit_softcap = 0.0;
273 config.sliding_window = 0;
274 config.has_combined_qkv = false;
275 config.uses_layer_norm = false;
276 config.uses_gelu = false;
277 config.has_ffn_gate = true;
278 config
279 }
280
281 pub fn llama3_8b() -> Self {
283 Self {
284 vocab_size: 128256,
285 hidden_size: 4096,
286 intermediate_size: 14336,
287 num_layers: 32,
288 num_heads: 32,
289 num_kv_heads: 8, head_dim: 128,
291 max_seq_len: 8192,
292 norm_eps: 1e-5,
293 rope_config: RopeConfig {
294 freq_base: 500000.0,
295 freq_scale: 1.0,
296 n_dims: 128,
297 scaling_type: RopeScalingType::None,
298 original_max_position_embeddings: 8192,
299 rope_type: RopeType::Normal,
300 mrope_sections: None,
301 },
302 use_parallel_residual: false,
303 hidden_act: ActivationType::SiLU,
304 attention_bias: false,
305 mlp_bias: false,
306 tie_word_embeddings: false,
307 num_experts: 0,
308 num_experts_per_token: 0,
309 expert_intermediate_size: 0,
310 key_length: 128,
311 value_length: 128,
312 ssm_d_inner: 0,
313 ssm_d_state: 0,
314 ssm_n_group: 0,
315 ssm_dt_rank: 0,
316 ssm_conv_kernel: 0,
317 attn_logit_softcap: 0.0,
318 final_logit_softcap: 0.0,
319 sliding_window: 0,
320 has_combined_qkv: false,
321 uses_layer_norm: false,
322 uses_gelu: false,
323 has_ffn_gate: true,
324 attention_layer_configs: None,
325 kv_source_layer: None,
326 }
327 }
328
329 pub fn uses_gqa(&self) -> bool {
331 self.num_kv_heads < self.num_heads
332 }
333
334 pub fn num_queries_per_kv(&self) -> usize {
336 self.num_heads / self.num_kv_heads
337 }
338
339 pub fn build_attention_layer_configs(
344 num_layers: usize,
345 pattern_period: usize,
346 sliding_head_dim: usize,
347 sliding_kv_heads: usize,
348 sliding_rope_freq_base: f32,
349 sliding_window: usize,
350 global_head_dim: usize,
351 global_kv_heads: usize,
352 global_rope_freq_base: f32,
353 global_rope_dims: usize,
354 ) -> Vec<AttentionLayerConfig> {
355 (0..num_layers)
356 .map(|i| {
357 if i % pattern_period == pattern_period - 1 {
358 AttentionLayerConfig {
359 layer_type: AttentionLayerType::Global,
360 head_dim: global_head_dim,
361 num_kv_heads: global_kv_heads,
362 rope_freq_base: global_rope_freq_base,
363 rope_dims: global_rope_dims,
364 sliding_window: 0,
365 }
366 } else {
367 AttentionLayerConfig {
368 layer_type: AttentionLayerType::Sliding,
369 head_dim: sliding_head_dim,
370 num_kv_heads: sliding_kv_heads,
371 rope_freq_base: sliding_rope_freq_base,
372 rope_dims: sliding_head_dim,
373 sliding_window,
374 }
375 }
376 })
377 .collect()
378 }
379
380 #[allow(clippy::too_many_arguments)]
386 pub fn build_attention_layer_configs_from_pattern(
387 is_swa: &[bool],
388 sliding_head_dim: usize,
389 sliding_kv_heads: usize,
390 sliding_rope_freq_base: f32,
391 sliding_rope_dims: usize,
392 sliding_window: usize,
393 global_head_dim: usize,
394 global_kv_heads: usize,
395 global_rope_freq_base: f32,
396 global_rope_dims: usize,
397 ) -> Vec<AttentionLayerConfig> {
398 is_swa
399 .iter()
400 .map(|&swa| {
401 if swa {
402 AttentionLayerConfig {
403 layer_type: AttentionLayerType::Sliding,
404 head_dim: sliding_head_dim,
405 num_kv_heads: sliding_kv_heads,
406 rope_freq_base: sliding_rope_freq_base,
407 rope_dims: sliding_rope_dims,
408 sliding_window,
409 }
410 } else {
411 AttentionLayerConfig {
412 layer_type: AttentionLayerType::Global,
413 head_dim: global_head_dim,
414 num_kv_heads: global_kv_heads,
415 rope_freq_base: global_rope_freq_base,
416 rope_dims: global_rope_dims,
417 sliding_window: 0,
418 }
419 }
420 })
421 .collect()
422 }
423
424 pub fn build_kv_source_mapping(
433 num_layers: usize,
434 shared_layers: usize,
435 layer_configs: &[AttentionLayerConfig],
436 ) -> Vec<usize> {
437 if shared_layers == 0 || shared_layers >= num_layers {
438 return (0..num_layers).collect();
439 }
440 let kv_boundary = num_layers - shared_layers;
441
442 let mut last_swa_kv = 0;
444 let mut last_global_kv = 0;
445 for i in 0..kv_boundary {
446 match layer_configs[i].layer_type {
447 AttentionLayerType::Sliding => last_swa_kv = i,
448 AttentionLayerType::Global => last_global_kv = i,
449 }
450 }
451
452 (0..num_layers)
453 .map(|i| {
454 if i < kv_boundary {
455 i } else {
457 match layer_configs[i].layer_type {
459 AttentionLayerType::Sliding => last_swa_kv,
460 AttentionLayerType::Global => last_global_kv,
461 }
462 }
463 })
464 .collect()
465 }
466}
467
468#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
470pub enum ActivationType {
471 GELU,
473 GELUApprox,
475 #[default]
477 SiLU,
478 ReLU,
480 ReLUSquared,
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the baseline dimensions.
    #[test]
    fn test_default_config() {
        let ModelConfig {
            vocab_size,
            hidden_size,
            num_layers,
            ..
        } = ModelConfig::default();
        assert_eq!(vocab_size, 32000);
        assert_eq!(hidden_size, 4096);
        assert_eq!(num_layers, 32);
    }

    /// Llama 3 8B shares each KV head between four query heads.
    #[test]
    fn test_llama3_gqa() {
        let llama3 = ModelConfig::llama3_8b();
        assert!(llama3.uses_gqa());
        assert_eq!(llama3.num_queries_per_kv(), 4);
    }

    /// Llama 7B has one KV head per query head.
    #[test]
    fn test_llama_no_gqa() {
        let llama = ModelConfig::llama_7b();
        assert!(!llama.uses_gqa());
        assert_eq!(llama.num_queries_per_kv(), 1);
    }

    /// With a period of 6, layers 5 and 11 are global and the rest sliding.
    #[test]
    fn test_attention_layer_configs_pattern() {
        let configs = ModelConfig::build_attention_layer_configs(
            12, 6, 256, 4, 10000.0, 1024, 512, 2, 1_000_000.0, 128,
        );
        assert_eq!(configs.len(), 12);
        for (idx, layer) in configs.iter().enumerate() {
            let closes_period = idx % 6 == 5;
            if closes_period {
                assert_eq!(layer.layer_type, AttentionLayerType::Global);
                assert_eq!(layer.head_dim, 512);
                assert_eq!(layer.num_kv_heads, 2);
                assert_eq!(layer.sliding_window, 0);
                assert_eq!(layer.rope_dims, 128);
            } else {
                assert_eq!(layer.layer_type, AttentionLayerType::Sliding);
                assert_eq!(layer.head_dim, 256);
                assert_eq!(layer.num_kv_heads, 4);
                assert_eq!(layer.sliding_window, 1024);
                assert_eq!(layer.rope_dims, 256);
            }
        }
    }

    /// An explicit boolean pattern decides each layer's attention type.
    #[test]
    fn test_attention_layer_configs_from_bool_pattern() {
        // Every fifth layer (indices 4, 9, …, 34) is global.
        let pattern: Vec<bool> = (0..35).map(|i| i % 5 != 4).collect();
        let configs = ModelConfig::build_attention_layer_configs_from_pattern(
            &pattern, 256, 1, 10000.0, 256, 512, 512, 1, 1_000_000.0, 512,
        );
        assert_eq!(configs.len(), 35);

        let first = &configs[0];
        assert_eq!(first.layer_type, AttentionLayerType::Sliding);
        assert_eq!(first.head_dim, 256);

        let fifth = &configs[4];
        assert_eq!(fifth.layer_type, AttentionLayerType::Global);
        assert_eq!(fifth.head_dim, 512);
        assert_eq!(fifth.sliding_window, 0);

        assert_eq!(configs[34].layer_type, AttentionLayerType::Global);
    }

    /// Without shared layers every layer is its own KV source.
    #[test]
    fn test_kv_source_mapping_no_sharing() {
        let configs = ModelConfig::build_attention_layer_configs(
            6, 6, 256, 4, 10000.0, 1024, 512, 2, 1_000_000.0, 128,
        );
        let identity: Vec<usize> = (0..6).collect();
        assert_eq!(
            ModelConfig::build_kv_source_mapping(6, 0, &configs),
            identity
        );
    }

    /// Shared layers reuse the KV of the last owning layer of their own type.
    #[test]
    fn test_kv_source_mapping_type_specific() {
        // Period 5 over 12 layers: layers 4 and 9 are global.
        let configs = ModelConfig::build_attention_layer_configs(
            12, 5, 256, 4, 10000.0, 1024, 512, 2, 1_000_000.0, 128,
        );
        // 7 shared layers: layers 0..5 own KV, layers 5..12 borrow it.
        let mapping = ModelConfig::build_kv_source_mapping(12, 7, &configs);
        assert_eq!(mapping.len(), 12);
        for (i, &source) in mapping.iter().enumerate().take(5) {
            assert_eq!(source, i, "layer {i}");
        }
        // Sliding layers borrow from layer 3, the global layer 9 from layer 4.
        assert_eq!(&mapping[5..], &[3, 3, 3, 3, 4, 3, 3][..]);
    }
}