// inference_lab/config/model.rs
1use serde::Deserialize;
2
/// Static description of a transformer model's architecture, loaded from
/// configuration (serde), plus one derived field filled in at runtime.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
    /// Model name
    pub name: String,

    /// Total parameters in the model (all parameters, including inactive experts in MoE)
    pub num_parameters: u64,

    /// Active parameters used during inference (for MoE models with sparse activation)
    /// If not specified, defaults to num_parameters (dense models)
    #[serde(default)]
    pub num_active_parameters: Option<u64>,

    /// Number of transformer layers
    pub num_layers: u32,

    /// Hidden dimension
    pub hidden_dim: u32,

    /// Number of attention heads
    pub num_heads: u32,

    /// Number of KV heads (for GQA/MQA). If not specified, defaults to num_heads (MHA)
    #[serde(default)]
    pub num_kv_heads: Option<u32>,

    /// Maximum sequence length supported
    pub max_seq_len: u32,

    /// Sliding window size for sliding window attention layers (None = no sliding window)
    /// Only applies to layers marked as using sliding window attention
    #[serde(default)]
    pub sliding_window: Option<u32>,

    /// Number of layers using sliding window attention (rest use full attention)
    /// If not specified, defaults to 0 (all layers use full attention)
    #[serde(default)]
    pub num_sliding_layers: Option<u32>,

    /// KV cache size per token across ALL layers (in bytes).
    /// For GQA: 2 * num_kv_heads * head_dim * bytes_per_param * num_layers
    /// For MHA: 2 * num_heads * head_dim * bytes_per_param * num_layers
    ///
    /// Not deserialized (`#[serde(skip)]` → defaults to 0); must be populated
    /// via `compute_kv_cache_size` / `with_kv_cache_size` before use.
    #[serde(skip)]
    pub kv_cache_bytes_per_token: u64,
}
48
49impl ModelConfig {
50 /// Get the number of active parameters (defaults to total parameters for dense models)
51 pub fn active_parameters(&self) -> u64 {
52 self.num_active_parameters.unwrap_or(self.num_parameters)
53 }
54
55 /// Calculate and set the KV cache size per token
56 /// For models with sliding window attention, this calculates an average based on typical usage
57 pub fn compute_kv_cache_size(&mut self, bytes_per_param: u32) {
58 // Use num_kv_heads if specified (GQA/MQA), otherwise use num_heads (MHA)
59 let kv_heads = self.num_kv_heads.unwrap_or(self.num_heads);
60 let head_dim = self.hidden_dim / self.num_heads;
61
62 // Bytes per token per layer
63 let bytes_per_token_per_layer = 2 * kv_heads as u64 * head_dim as u64 * bytes_per_param as u64;
64
65 // If no sliding window, all layers use full attention
66 if self.sliding_window.is_none() || self.num_sliding_layers.is_none() {
67 self.kv_cache_bytes_per_token = bytes_per_token_per_layer * self.num_layers as u64;
68 return;
69 }
70
71 // With sliding window: some layers cap at window size, others grow with sequence
72 let _num_sliding = self.num_sliding_layers.unwrap_or(0);
73
74 // All layers contribute equally per token (sliding window just caps maximum)
75 // At short sequences (< window), all layers grow at same rate
76 // This is correct for the initial growth phase
77 self.kv_cache_bytes_per_token = bytes_per_token_per_layer * self.num_layers as u64;
78 }
79
80 /// Initialize with KV cache size pre-computed
81 pub fn with_kv_cache_size(mut self, bytes_per_param: u32) -> Self {
82 self.compute_kv_cache_size(bytes_per_param);
83 self
84 }
85
86 /// Calculate total KV cache size for a sequence, accounting for sliding window
87 pub fn kv_cache_size_for_sequence(&self, seq_len: u32) -> u64 {
88 // Use num_kv_heads if specified (GQA/MQA), otherwise use num_heads (MHA)
89 let kv_heads = self.num_kv_heads.unwrap_or(self.num_heads);
90 let head_dim = self.hidden_dim / self.num_heads;
91 let bytes_per_token_per_layer = 2 * kv_heads as u64 * head_dim as u64 * 1; // Assuming bytes_per_param
92
93 // No sliding window: simple linear growth
94 if self.sliding_window.is_none() || self.num_sliding_layers.is_none() {
95 return self.kv_cache_bytes_per_token * seq_len as u64;
96 }
97
98 let window = self.sliding_window.unwrap();
99 let num_sliding = self.num_sliding_layers.unwrap_or(0);
100 let num_full = self.num_layers.saturating_sub(num_sliding);
101
102 // Full attention layers: grow linearly with sequence length
103 let full_layers_kv = bytes_per_token_per_layer * num_full as u64 * seq_len as u64;
104
105 // Sliding window layers: capped at window size
106 let sliding_layers_kv = bytes_per_token_per_layer * num_sliding as u64 * seq_len.min(window) as u64;
107
108 full_layers_kv + sliding_layers_kv
109 }
110}