1use crate::{devices::*, ids::ModelId, FerrumError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Architecture family of a model, used to select the right loader/kernel path.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Meta Llama family.
    Llama,
    /// Mistral family.
    Mistral,
    /// Alibaba Qwen family.
    Qwen,
    /// Microsoft Phi family.
    Phi,
    /// Google Gemma family.
    Gemma,
    /// Code-specialized model; payload is the concrete model name.
    Code(String),
    /// Any architecture not covered above; payload is the concrete model name.
    Custom(String),
}
25
26impl std::fmt::Display for ModelType {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 ModelType::Llama => write!(f, "llama"),
30 ModelType::Mistral => write!(f, "mistral"),
31 ModelType::Qwen => write!(f, "qwen"),
32 ModelType::Phi => write!(f, "phi"),
33 ModelType::Gemma => write!(f, "gemma"),
34 ModelType::Code(name) => write!(f, "code-{}", name),
35 ModelType::Custom(name) => write!(f, "custom-{}", name),
36 }
37 }
38}
39
/// Descriptive metadata for a loaded model, including the architecture
/// dimensions needed for memory estimation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    // Unique identifier of the model.
    pub model_id: ModelId,
    // Architecture family.
    pub model_type: ModelType,
    // Total number of weights.
    pub num_parameters: u64,
    // Width of the hidden state.
    pub hidden_size: usize,
    // Number of transformer layers.
    pub num_layers: usize,
    // Number of attention heads.
    pub num_heads: usize,
    // Number of key/value heads used when sizing the KV cache
    // (may be smaller than `num_heads`, e.g. for grouped-query attention).
    pub num_kv_heads: usize,
    // Size of the token vocabulary.
    pub vocab_size: usize,
    // Maximum supported context length in tokens.
    pub max_sequence_length: usize,
    // Numeric precision of the weights.
    pub dtype: DataType,
    // Device the model is placed on.
    pub device: Device,
    // Optional model version string.
    pub version: Option<String>,
    // Optional license identifier.
    pub license: Option<String>,
    // Free-form additional metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
72
73impl ModelInfo {
74 pub fn estimated_size_bytes(&self) -> u64 {
76 let param_size = self.num_parameters * self.dtype.size_bytes() as u64;
78 (param_size as f64 * 1.2) as u64
80 }
81
82 pub fn supports_sequence_length(&self, length: usize) -> bool {
84 length <= self.max_sequence_length
85 }
86
87 pub fn memory_requirements(
89 &self,
90 batch_size: usize,
91 sequence_length: usize,
92 ) -> ModelMemoryRequirements {
93 let param_memory = self.estimated_size_bytes();
94
95 let head_dim = self.hidden_size / self.num_heads;
97 let kv_cache_per_token =
98 self.num_layers * self.num_kv_heads * head_dim * 2 * self.dtype.size_bytes();
99 let kv_cache_memory = (kv_cache_per_token * sequence_length * batch_size) as u64;
100
101 let activation_memory =
103 (self.hidden_size * sequence_length * batch_size * self.dtype.size_bytes()) as u64 * 4;
104
105 ModelMemoryRequirements {
106 parameter_memory: param_memory,
107 kv_cache_memory,
108 activation_memory,
109 total_estimated: param_memory + kv_cache_memory + activation_memory,
110 }
111 }
112}
113
/// Breakdown of the estimated memory needed to serve a model,
/// produced by [`ModelInfo::memory_requirements`]. All values are bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMemoryRequirements {
    // Memory for the model weights (including overhead estimate).
    pub parameter_memory: u64,
    // Memory for the key/value cache at the requested batch/sequence sizes.
    pub kv_cache_memory: u64,
    // Memory for intermediate activations.
    pub activation_memory: u64,
    // Sum of the three components above.
    pub total_estimated: u64,
}
126
/// Runtime configuration for loading and serving a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    // Unique identifier of the model.
    pub model_id: ModelId,
    // Filesystem path (or locator) of the model weights; must be non-empty.
    pub model_path: String,
    // Architecture family.
    pub model_type: ModelType,
    // Numeric precision to load the weights in.
    pub dtype: DataType,
    // Target device for execution.
    pub device: Device,
    // Maximum number of sequences per batch; must be > 0.
    pub max_batch_size: usize,
    // Maximum context length in tokens; must be > 0.
    pub max_sequence_length: usize,
    // Tensor-parallel degree; `None` disables tensor parallelism.
    pub tensor_parallel_size: Option<usize>,
    // Pipeline-parallel degree; `None` disables pipeline parallelism.
    pub pipeline_parallel_size: Option<usize>,
    // Optional weight quantization scheme.
    pub quantization: Option<QuantizationConfig>,
    // Enable flash-attention kernels.
    pub use_flash_attention: bool,
    // Enable paged-attention KV-cache management.
    pub use_paged_attention: bool,
    // Enable CUDA graph capture.
    pub enable_cuda_graphs: bool,
    // Free-form extra settings not covered by the typed fields.
    pub extra_config: HashMap<String, serde_json::Value>,
}
159
160impl ModelConfig {
161 pub fn new(model_id: impl Into<ModelId>, model_path: impl Into<String>) -> Self {
163 Self {
164 model_id: model_id.into(),
165 model_path: model_path.into(),
166 model_type: ModelType::Custom("unknown".to_string()),
167 dtype: DataType::FP16,
168 device: Device::CPU,
169 max_batch_size: 1,
170 max_sequence_length: 2048,
171 tensor_parallel_size: None,
172 pipeline_parallel_size: None,
173 quantization: None,
174 use_flash_attention: false,
175 use_paged_attention: false,
176 enable_cuda_graphs: false,
177 extra_config: HashMap::new(),
178 }
179 }
180
181 pub fn validate(&self) -> Result<()> {
183 if self.model_path.is_empty() {
184 return Err(FerrumError::config("Model path cannot be empty"));
185 }
186
187 if self.max_batch_size == 0 {
188 return Err(FerrumError::config("Max batch size must be positive"));
189 }
190
191 if self.max_sequence_length == 0 {
192 return Err(FerrumError::config("Max sequence length must be positive"));
193 }
194
195 if let Some(tp_size) = self.tensor_parallel_size {
196 if tp_size == 0 {
197 return Err(FerrumError::config("Tensor parallel size must be positive"));
198 }
199 }
200
201 if let Some(pp_size) = self.pipeline_parallel_size {
202 if pp_size == 0 {
203 return Err(FerrumError::config(
204 "Pipeline parallel size must be positive",
205 ));
206 }
207 }
208
209 Ok(())
210 }
211}
212
/// Weight quantization scheme and its scheme-specific parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// GPTQ post-training quantization.
    GPTQ {
        // Weight bit-width.
        bits: u8,
        // Number of weights sharing one scale.
        group_size: usize,
        // Activation-order ("desc_act") quantization.
        desc_act: bool,
    },
    /// Activation-aware weight quantization.
    AWQ {
        // Weight bit-width.
        bits: u8,
        // Whether zero-points are used.
        zero_point: bool,
        // AWQ format version string.
        version: String,
    },
    /// 8-bit floating point; `e4m3` selects the E4M3 format,
    /// `kv_cache` also quantizes the KV cache.
    FP8 { e4m3: bool, kv_cache: bool },
    /// 8-bit integer quantization.
    INT8 { symmetric: bool, per_channel: bool },
    /// 4-bit integer quantization with grouped scales.
    INT4 { symmetric: bool, group_size: usize },
    /// SmoothQuant (8-bit) with migration strength `alpha` and a
    /// calibration-set size.
    SmoothQuant { alpha: f32, calibration_size: usize },
}
237
238impl QuantizationConfig {
239 pub fn bits(&self) -> u8 {
241 match self {
242 QuantizationConfig::GPTQ { bits, .. } => *bits,
243 QuantizationConfig::AWQ { bits, .. } => *bits,
244 QuantizationConfig::FP8 { .. } => 8,
245 QuantizationConfig::INT8 { .. } => 8,
246 QuantizationConfig::INT4 { .. } => 4,
247 QuantizationConfig::SmoothQuant { .. } => 8,
248 }
249 }
250
251 pub fn is_high_accuracy(&self) -> bool {
253 match self {
254 QuantizationConfig::FP8 { .. } => true,
255 QuantizationConfig::INT8 { .. } => true,
256 QuantizationConfig::SmoothQuant { .. } => true,
257 _ => false,
258 }
259 }
260}
261
/// Token accounting for a single request (OpenAI-style usage record).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    // Tokens in the input prompt.
    pub prompt_tokens: usize,
    // Tokens generated so far.
    pub completion_tokens: usize,
    // Invariant: prompt_tokens + completion_tokens.
    pub total_tokens: usize,
}
272
273impl TokenUsage {
274 pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
276 Self {
277 prompt_tokens,
278 completion_tokens,
279 total_tokens: prompt_tokens + completion_tokens,
280 }
281 }
282
283 pub fn add_completion_tokens(&mut self, tokens: usize) {
285 self.completion_tokens += tokens;
286 self.total_tokens = self.prompt_tokens + self.completion_tokens;
287 }
288}
289
/// RoPE (rotary position embedding) scaling settings for context extension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
    // Scaling strategy name (e.g. "linear", "dynamic" — not validated here).
    pub scaling_type: String,
    // Multiplicative scaling factor.
    pub factor: f32,
}
298
/// Normalization layer variant used by the architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormType {
    /// Standard layer normalization.
    LayerNorm,
    /// Root-mean-square normalization.
    RMSNorm,
}
307
/// Feed-forward activation function used by the architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
    /// Gaussian error linear unit.
    GELU,
    /// Sigmoid-weighted linear unit.
    SiLU,
    /// Rectified linear unit.
    ReLU,
    /// Swish activation.
    Swish,
}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct AttentionConfig {
324 pub attention_bias: bool,
326 pub sliding_window: Option<usize>,
328}
329
330impl Default for AttentionConfig {
331 fn default() -> Self {
332 Self {
333 attention_bias: false,
334 sliding_window: None,
335 }
336 }
337}
338
/// Location the model weights are obtained from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// Path on the local filesystem.
    Local(String),
    /// A Hugging Face Hub repository.
    HuggingFace {
        // Repository id, presumably "org/name" — not validated here.
        repo_id: String,
        // Branch, tag, or commit; `None` uses the default revision.
        revision: Option<String>,
        // Override for the local download cache directory.
        cache_dir: Option<String>,
    },
    /// Direct download from a URL.
    Url {
        // Download URL.
        url: String,
        // Extra request headers (e.g. authorization).
        headers: HashMap<String, String>,
    },
    /// An object in S3-compatible storage.
    S3 {
        // Bucket name.
        bucket: String,
        // Object key within the bucket.
        key: String,
        // AWS region; `None` presumably falls back to a default — verify against loader.
        region: Option<String>,
        // Custom endpoint for S3-compatible services.
        endpoint: Option<String>,
    },
}