1use crate::{devices::*, ids::ModelId, FerrumError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
/// Architecture family of a model, used to pick the appropriate loading and
/// inference path.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Llama-family decoder model.
    Llama,
    /// Mistral-family decoder model.
    Mistral,
    /// Qwen-family decoder model.
    Qwen,
    /// Phi-family decoder model.
    Phi,
    /// Gemma-family decoder model.
    Gemma,
    /// Code-generation model; the payload is the concrete model name
    /// (rendered as `code-<name>` by `Display`).
    Code(String),
    /// Text-embedding model.
    Embedding,
    /// CLIP-style model.
    Clip,
    /// Any architecture not covered above; the payload is a free-form name
    /// (rendered as `custom-<name>` by `Display`).
    Custom(String),
}
29
30impl std::fmt::Display for ModelType {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 ModelType::Llama => write!(f, "llama"),
34 ModelType::Mistral => write!(f, "mistral"),
35 ModelType::Qwen => write!(f, "qwen"),
36 ModelType::Phi => write!(f, "phi"),
37 ModelType::Gemma => write!(f, "gemma"),
38 ModelType::Embedding => write!(f, "embedding"),
39 ModelType::Clip => write!(f, "clip"),
40 ModelType::Code(name) => write!(f, "code-{}", name),
41 ModelType::Custom(name) => write!(f, "custom-{}", name),
42 }
43 }
44}
45
/// Static description of a loaded model: architecture dimensions, runtime
/// placement, and free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Unique identifier of the model.
    pub model_id: ModelId,
    /// Architecture family.
    pub model_type: ModelType,
    /// Total parameter count.
    pub num_parameters: u64,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads (differs from `num_heads` under GQA/MQA).
    pub num_kv_heads: usize,
    /// Tokenizer vocabulary size.
    pub vocab_size: usize,
    /// Maximum supported context length in tokens.
    pub max_sequence_length: usize,
    /// Data type of the weights.
    pub dtype: DataType,
    /// Device the model is placed on.
    pub device: Device,
    /// Optional model version string.
    pub version: Option<String>,
    /// Optional license identifier.
    pub license: Option<String>,
    /// Additional free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
78
79impl ModelInfo {
80 pub fn estimated_size_bytes(&self) -> u64 {
82 let param_size = self.num_parameters * self.dtype.size_bytes() as u64;
84 (param_size as f64 * 1.2) as u64
86 }
87
88 pub fn supports_sequence_length(&self, length: usize) -> bool {
90 length <= self.max_sequence_length
91 }
92
93 pub fn memory_requirements(
95 &self,
96 batch_size: usize,
97 sequence_length: usize,
98 ) -> ModelMemoryRequirements {
99 let param_memory = self.estimated_size_bytes();
100
101 let head_dim = self.hidden_size / self.num_heads;
103 let kv_cache_per_token =
104 self.num_layers * self.num_kv_heads * head_dim * 2 * self.dtype.size_bytes();
105 let kv_cache_memory = (kv_cache_per_token * sequence_length * batch_size) as u64;
106
107 let activation_memory =
109 (self.hidden_size * sequence_length * batch_size * self.dtype.size_bytes()) as u64 * 4;
110
111 ModelMemoryRequirements {
112 parameter_memory: param_memory,
113 kv_cache_memory,
114 activation_memory,
115 total_estimated: param_memory + kv_cache_memory + activation_memory,
116 }
117 }
118}
119
/// Breakdown of estimated memory usage for serving a model, in bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMemoryRequirements {
    /// Memory for the model weights (including overhead factor).
    pub parameter_memory: u64,
    /// Memory for the key/value cache at the requested batch/sequence size.
    pub kv_cache_memory: u64,
    /// Memory for activations and intermediate buffers.
    pub activation_memory: u64,
    /// Sum of the three components above.
    pub total_estimated: u64,
}
132
/// User-supplied configuration for loading and serving a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Unique identifier for the model instance.
    pub model_id: ModelId,
    /// Filesystem path (or locator string) of the model weights.
    pub model_path: String,
    /// Architecture family.
    pub model_type: ModelType,
    /// Data type to load/run the weights in.
    pub dtype: DataType,
    /// Target device.
    pub device: Device,
    /// Maximum number of requests batched together.
    pub max_batch_size: usize,
    /// Maximum context length in tokens.
    pub max_sequence_length: usize,
    /// Tensor-parallel degree; `None` disables tensor parallelism.
    pub tensor_parallel_size: Option<usize>,
    /// Pipeline-parallel degree; `None` disables pipeline parallelism.
    pub pipeline_parallel_size: Option<usize>,
    /// Optional weight quantization scheme.
    pub quantization: Option<QuantizationConfig>,
    /// Enable flash attention kernels.
    pub use_flash_attention: bool,
    /// Enable paged attention for KV-cache management.
    pub use_paged_attention: bool,
    /// Enable CUDA graph capture.
    pub enable_cuda_graphs: bool,
    /// Additional free-form configuration values.
    pub extra_config: HashMap<String, serde_json::Value>,
}
165
166impl ModelConfig {
167 pub fn new(model_id: impl Into<ModelId>, model_path: impl Into<String>) -> Self {
169 Self {
170 model_id: model_id.into(),
171 model_path: model_path.into(),
172 model_type: ModelType::Custom("unknown".to_string()),
173 dtype: DataType::FP16,
174 device: Device::CPU,
175 max_batch_size: 1,
176 max_sequence_length: 2048,
177 tensor_parallel_size: None,
178 pipeline_parallel_size: None,
179 quantization: None,
180 use_flash_attention: false,
181 use_paged_attention: false,
182 enable_cuda_graphs: false,
183 extra_config: HashMap::new(),
184 }
185 }
186
187 pub fn validate(&self) -> Result<()> {
189 if self.model_path.is_empty() {
190 return Err(FerrumError::config("Model path cannot be empty"));
191 }
192
193 if self.max_batch_size == 0 {
194 return Err(FerrumError::config("Max batch size must be positive"));
195 }
196
197 if self.max_sequence_length == 0 {
198 return Err(FerrumError::config("Max sequence length must be positive"));
199 }
200
201 if let Some(tp_size) = self.tensor_parallel_size {
202 if tp_size == 0 {
203 return Err(FerrumError::config("Tensor parallel size must be positive"));
204 }
205 }
206
207 if let Some(pp_size) = self.pipeline_parallel_size {
208 if pp_size == 0 {
209 return Err(FerrumError::config(
210 "Pipeline parallel size must be positive",
211 ));
212 }
213 }
214
215 Ok(())
216 }
217}
218
/// Supported weight-quantization schemes and their parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// GPTQ quantization.
    GPTQ {
        /// Weight bit-width.
        bits: u8,
        /// Quantization group size.
        group_size: usize,
        /// Whether activation-order (desc_act) quantization was used.
        desc_act: bool,
    },
    /// AWQ quantization.
    AWQ {
        /// Weight bit-width.
        bits: u8,
        /// Whether a zero point is used.
        zero_point: bool,
        /// AWQ format/version string.
        version: String,
    },
    /// 8-bit floating point; `e4m3` selects the E4M3 format, `kv_cache`
    /// enables FP8 for the KV cache as well.
    FP8 { e4m3: bool, kv_cache: bool },
    /// 8-bit integer quantization.
    INT8 { symmetric: bool, per_channel: bool },
    /// 4-bit integer quantization with the given group size.
    INT4 { symmetric: bool, group_size: usize },
    /// SmoothQuant with migration strength `alpha` and the number of
    /// calibration samples.
    SmoothQuant { alpha: f32, calibration_size: usize },
}
243
244impl QuantizationConfig {
245 pub fn bits(&self) -> u8 {
247 match self {
248 QuantizationConfig::GPTQ { bits, .. } => *bits,
249 QuantizationConfig::AWQ { bits, .. } => *bits,
250 QuantizationConfig::FP8 { .. } => 8,
251 QuantizationConfig::INT8 { .. } => 8,
252 QuantizationConfig::INT4 { .. } => 4,
253 QuantizationConfig::SmoothQuant { .. } => 8,
254 }
255 }
256
257 pub fn is_high_accuracy(&self) -> bool {
259 match self {
260 QuantizationConfig::FP8 { .. } => true,
261 QuantizationConfig::INT8 { .. } => true,
262 QuantizationConfig::SmoothQuant { .. } => true,
263 _ => false,
264 }
265 }
266}
267
/// Token accounting for a single request, OpenAI-style.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens in the prompt.
    pub prompt_tokens: usize,
    /// Tokens generated so far.
    pub completion_tokens: usize,
    /// Invariant: `prompt_tokens + completion_tokens`.
    pub total_tokens: usize,
}
278
279impl TokenUsage {
280 pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
282 Self {
283 prompt_tokens,
284 completion_tokens,
285 total_tokens: prompt_tokens + completion_tokens,
286 }
287 }
288
289 pub fn add_completion_tokens(&mut self, tokens: usize) {
291 self.completion_tokens += tokens;
292 self.total_tokens = self.prompt_tokens + self.completion_tokens;
293 }
294}
295
/// RoPE (rotary position embedding) scaling configuration for context
/// extension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
    /// Name of the scaling method (stringly-typed; values defined by callers).
    pub scaling_type: String,
    /// Scaling factor applied to the rotary embeddings.
    pub factor: f32,
}
304
/// Normalization layer variant used by a model architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormType {
    /// Standard layer normalization.
    LayerNorm,
    /// Root-mean-square normalization.
    RMSNorm,
}
313
/// Activation function used in a model's feed-forward layers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
    /// Gaussian error linear unit.
    GELU,
    /// Sigmoid linear unit.
    SiLU,
    /// Rectified linear unit.
    ReLU,
    /// Swish activation (kept distinct from SiLU here, although the two are
    /// commonly treated as equivalent — callers may rely on the distinction).
    Swish,
}
326
/// Attention-specific architecture options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Whether the attention projections use a bias term.
    pub attention_bias: bool,
    /// Sliding-window size in tokens; `None` means full attention.
    pub sliding_window: Option<usize>,
}
335
336impl Default for AttentionConfig {
337 fn default() -> Self {
338 Self {
339 attention_bias: false,
340 sliding_window: None,
341 }
342 }
343}
344
/// Where model weights are fetched from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// Local filesystem path.
    Local(String),
    /// Hugging Face Hub repository.
    HuggingFace {
        /// Repository id, e.g. `org/model`.
        repo_id: String,
        /// Branch, tag, or commit; `None` uses the default revision.
        revision: Option<String>,
        /// Local cache directory override.
        cache_dir: Option<String>,
    },
    /// Direct HTTP(S) download.
    Url {
        /// Download URL.
        url: String,
        /// Extra request headers (e.g. authorization).
        headers: HashMap<String, String>,
    },
    /// S3-compatible object storage.
    S3 {
        /// Bucket name.
        bucket: String,
        /// Object key.
        key: String,
        /// AWS region; `None` relies on environment defaults.
        region: Option<String>,
        /// Custom endpoint for S3-compatible stores.
        endpoint: Option<String>,
    },
}