
inference_lab/config/model.rs

use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
    /// Model name
    pub name: String,

    /// Total parameters in the model (all parameters, including inactive experts in MoE)
    pub num_parameters: u64,

    /// Active parameters used during inference (for MoE models with sparse activation)
    /// If not specified, defaults to num_parameters (dense models)
    #[serde(default)]
    pub num_active_parameters: Option<u64>,

    /// Number of transformer layers
    pub num_layers: u32,

    /// Hidden dimension
    pub hidden_dim: u32,

    /// Number of attention heads
    pub num_heads: u32,

    /// Number of KV heads (for GQA/MQA). If not specified, defaults to num_heads (MHA)
    #[serde(default)]
    pub num_kv_heads: Option<u32>,

    /// Maximum sequence length supported
    pub max_seq_len: u32,

    /// Sliding window size for sliding window attention layers (None = no sliding window)
    /// Only applies to layers marked as using sliding window attention
    #[serde(default)]
    pub sliding_window: Option<u32>,

    /// Number of layers using sliding window attention (rest use full attention)
    /// If not specified, defaults to 0 (all layers use full attention)
    #[serde(default)]
    pub num_sliding_layers: Option<u32>,

    /// KV cache size per token across all layers (in bytes)
    /// For GQA: 2 * num_kv_heads * head_dim * bytes_per_param * num_layers
    /// For MHA: 2 * num_heads * head_dim * bytes_per_param * num_layers
    #[serde(skip)]
    pub kv_cache_bytes_per_token: u64,
}
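
// A hypothetical config snippet this struct could deserialize from. The TOML
// format and every value below are assumptions for illustration only; fields
// marked #[serde(default)] may be omitted, and kv_cache_bytes_per_token is
// always skipped and computed later:
//
//   name = "example-8b"
//   num_parameters = 8000000000
//   num_layers = 32
//   hidden_dim = 4096
//   num_heads = 32
//   num_kv_heads = 8          # GQA; omit for MHA
//   max_seq_len = 8192
//   sliding_window = 4096     # omit if every layer uses full attention
//   num_sliding_layers = 16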

impl ModelConfig {
    /// Get the number of active parameters (defaults to total parameters for dense models)
    pub fn active_parameters(&self) -> u64 {
        self.num_active_parameters.unwrap_or(self.num_parameters)
    }

    /// Calculate and set the KV cache size per token
    /// For models with sliding window attention, this reflects the initial growth phase
    /// (sequence shorter than the window), where every layer adds the same bytes per token;
    /// use kv_cache_size_for_sequence for totals once the window starts capping layers
    pub fn compute_kv_cache_size(&mut self, bytes_per_param: u32) {
        // Use num_kv_heads if specified (GQA/MQA), otherwise use num_heads (MHA)
        let kv_heads = self.num_kv_heads.unwrap_or(self.num_heads);
        let head_dim = self.hidden_dim / self.num_heads;

        // Bytes per token per layer (K and V, hence the factor of 2)
        let bytes_per_token_per_layer =
            2 * kv_heads as u64 * head_dim as u64 * bytes_per_param as u64;

        // Every layer contributes equally per token: sliding window attention only caps
        // the maximum number of cached tokens, so at short sequences (< window) sliding
        // and full-attention layers grow at the same rate.
        self.kv_cache_bytes_per_token = bytes_per_token_per_layer * self.num_layers as u64;
    }
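
    // Worked example (hypothetical numbers, assuming an fp16 KV cache, i.e.
    // bytes_per_param = 2): an MQA model with hidden_dim = 4096, num_heads = 32
    // (so head_dim = 128), num_kv_heads = 1, and num_layers = 32 caches
    // 2 * 1 * 128 * 2 = 512 bytes per token per layer, or 512 * 32 = 16384
    // bytes (16 KiB) per token across the whole model.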

    /// Initialize with KV cache size pre-computed
    pub fn with_kv_cache_size(mut self, bytes_per_param: u32) -> Self {
        self.compute_kv_cache_size(bytes_per_param);
        self
    }

    /// Calculate total KV cache size for a sequence, accounting for sliding window
    /// Requires compute_kv_cache_size (or with_kv_cache_size) to have been called first
    pub fn kv_cache_size_for_sequence(&self, seq_len: u32) -> u64 {
        // No sliding window: simple linear growth
        if self.sliding_window.is_none() || self.num_sliding_layers.is_none() {
            return self.kv_cache_bytes_per_token * seq_len as u64;
        }

        // Per-layer bytes per token, derived from the pre-computed total so the same
        // bytes_per_param is used consistently instead of being hard-coded here
        let bytes_per_token_per_layer = self.kv_cache_bytes_per_token / self.num_layers as u64;

        let window = self.sliding_window.unwrap();
        let num_sliding = self.num_sliding_layers.unwrap_or(0);
        let num_full = self.num_layers.saturating_sub(num_sliding);

        // Full attention layers: grow linearly with sequence length
        let full_layers_kv = bytes_per_token_per_layer * num_full as u64 * seq_len as u64;

        // Sliding window layers: capped at window size
        let sliding_layers_kv =
            bytes_per_token_per_layer * num_sliding as u64 * seq_len.min(window) as u64;

        full_layers_kv + sliding_layers_kv
    }
}
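
// --- Usage sketch (tests added for illustration; not part of the original file) ---
// A minimal, hedged example of how this config might be exercised. The model
// dimensions below are hypothetical and chosen only to keep the arithmetic easy
// to follow; an fp16/bf16 KV cache (2 bytes per parameter) is assumed.
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_config() -> ModelConfig {
        ModelConfig {
            name: "example-8b".to_string(), // hypothetical model
            num_parameters: 8_000_000_000,
            num_active_parameters: None, // dense: active == total
            num_layers: 32,
            hidden_dim: 4096,
            num_heads: 32,
            num_kv_heads: Some(8), // GQA with 8 KV heads
            max_seq_len: 8192,
            sliding_window: Some(4096),
            num_sliding_layers: Some(16), // half the layers use sliding window attention
            kv_cache_bytes_per_token: 0,  // filled in by compute_kv_cache_size
        }
    }

    #[test]
    fn dense_model_defaults_active_to_total_parameters() {
        assert_eq!(sample_config().active_parameters(), 8_000_000_000);
    }

    #[test]
    fn per_token_kv_cache_matches_the_gqa_formula() {
        let cfg = sample_config().with_kv_cache_size(2);
        // head_dim = 4096 / 32 = 128; per layer: 2 * 8 * 128 * 2 = 4096 bytes;
        // across 32 layers: 131072 bytes per token.
        assert_eq!(cfg.kv_cache_bytes_per_token, 131_072);
    }

    #[test]
    fn sliding_window_layers_cap_at_the_window_size() {
        let cfg = sample_config().with_kv_cache_size(2);
        let per_layer = 4_096u64; // 2 * 8 * 128 * 2
        // At 8192 tokens: 16 full-attention layers hold all 8192 tokens,
        // while the 16 sliding-window layers are capped at 4096 tokens each.
        let expected = per_layer * 16 * 8192 + per_layer * 16 * 4096;
        assert_eq!(cfg.kv_cache_size_for_sequence(8192), expected);
    }
}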