// inference_lab/config/model.rs
1use serde::Deserialize;
2
/// Static description of a transformer model used for inference-time
/// memory/size accounting (KV cache sizing in particular).
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
    /// Model name (informational; not used in any calculation here).
    pub name: String,

    /// Total parameters in the model (all parameters, including inactive experts in MoE)
    pub num_parameters: u64,

    /// Active parameters used during inference (for MoE models with sparse activation).
    /// If not specified, defaults to num_parameters (dense models) — see
    /// `active_parameters()`.
    #[serde(default)]
    pub num_active_parameters: Option<u64>,

    /// Number of transformer layers
    pub num_layers: u32,

    /// Hidden dimension (head_dim is derived as hidden_dim / num_heads)
    pub hidden_dim: u32,

    /// Number of attention (query) heads
    pub num_heads: u32,

    /// Number of KV heads (for GQA/MQA). If not specified, defaults to num_heads (MHA)
    #[serde(default)]
    pub num_kv_heads: Option<u32>,

    /// Maximum sequence length supported
    pub max_seq_len: u32,

    /// Sliding window size for sliding window attention layers (None = no sliding window)
    /// Only applies to layers marked as using sliding window attention
    #[serde(default)]
    pub sliding_window: Option<u32>,

    /// Number of layers using sliding window attention (rest use full attention)
    /// If not specified, defaults to 0 (all layers use full attention)
    #[serde(default)]
    pub num_sliding_layers: Option<u32>,

    /// KV cache size per token across all layers (in bytes):
    /// 2 * kv_heads * head_dim * bytes_per_param * num_layers
    /// (kv_heads = num_kv_heads for GQA/MQA, num_heads for MHA).
    /// Not deserialized — defaults to 0 until `compute_kv_cache_size` is called.
    #[serde(skip)]
    pub kv_cache_bytes_per_token: u64,
}
48
49impl ModelConfig {
50 /// Get the number of active parameters (defaults to total parameters for dense models)
51 pub fn active_parameters(&self) -> u64 {
52 self.num_active_parameters.unwrap_or(self.num_parameters)
53 }
54
55 /// Calculate and set the KV cache size per token
56 /// For models with sliding window attention, this calculates an average based on typical usage
57 pub fn compute_kv_cache_size(&mut self, bytes_per_param: u32) {
58 // Use num_kv_heads if specified (GQA/MQA), otherwise use num_heads (MHA)
59 let kv_heads = self.num_kv_heads.unwrap_or(self.num_heads);
60 let head_dim = self.hidden_dim / self.num_heads;
61
62 // Bytes per token per layer
63 let bytes_per_token_per_layer =
64 2 * kv_heads as u64 * head_dim as u64 * bytes_per_param as u64;
65
66 // If no sliding window, all layers use full attention
67 if self.sliding_window.is_none() || self.num_sliding_layers.is_none() {
68 self.kv_cache_bytes_per_token = bytes_per_token_per_layer * self.num_layers as u64;
69 return;
70 }
71
72 // With sliding window: some layers cap at window size, others grow with sequence
73 let _num_sliding = self.num_sliding_layers.unwrap_or(0);
74
75 // All layers contribute equally per token (sliding window just caps maximum)
76 // At short sequences (< window), all layers grow at same rate
77 // This is correct for the initial growth phase
78 self.kv_cache_bytes_per_token = bytes_per_token_per_layer * self.num_layers as u64;
79 }
80
81 /// Initialize with KV cache size pre-computed
82 pub fn with_kv_cache_size(mut self, bytes_per_param: u32) -> Self {
83 self.compute_kv_cache_size(bytes_per_param);
84 self
85 }
86
87 /// Calculate total KV cache size for a sequence, accounting for sliding window
88 pub fn kv_cache_size_for_sequence(&self, seq_len: u32) -> u64 {
89 let bytes_per_token_per_layer = self.kv_cache_bytes_per_token / self.num_layers as u64;
90
91 // No sliding window: simple linear growth
92 if self.sliding_window.is_none() || self.num_sliding_layers.is_none() {
93 return self.kv_cache_bytes_per_token * seq_len as u64;
94 }
95
96 let window = self.sliding_window.unwrap();
97 let num_sliding = self.num_sliding_layers.unwrap_or(0);
98 let num_full = self.num_layers.saturating_sub(num_sliding);
99
100 // Full attention layers: grow linearly with sequence length
101 let full_layers_kv = bytes_per_token_per_layer * num_full as u64 * seq_len as u64;
102
103 // Sliding window layers: capped at window size
104 let sliding_layers_kv =
105 bytes_per_token_per_layer * num_sliding as u64 * seq_len.min(window) as u64;
106
107 full_layers_kv + sliding_layers_kv
108 }
109}